{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 107,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np # get rid of this eventually\n",
    "import argparse\n",
    "from jax import jit\n",
    "from jax.experimental.ode import odeint\n",
    "from functools import partial # reduces arguments to function by making some subset implicit\n",
    "\n",
    "from jax.experimental import stax\n",
    "from jax.experimental import optimizers\n",
    "\n",
    "import os, sys, time\n",
    "sys.path.append('..')\n",
    "sys.path.append('../hyperopt')\n",
    "from HyperparameterSearch import extended_mlp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [],
   "source": [
    "from models import mlp as make_mlp\n",
    "from utils import wrap_coords"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax.experimental.ode import odeint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "metadata": {},
   "outputs": [],
   "source": [
    "def hamiltonian_eom(hamiltonian, state, conditionals, t=None):\n",
    "    q, p = jnp.split(state, 2)\n",
    "    q = q / 10.0 #Normalize\n",
    "    conditionals = conditionals / 10.0 #Normalize\n",
    "    q_t = jax.grad(hamiltonian, 1)(q, p, conditionals)\n",
    "    p_t = -jax.grad(hamiltonian, 0)(q, p, conditionals)\n",
    "    \n",
    "    q_tt = p_t #mass is 1\n",
    "    return jnp.concatenate([q_t, q_tt])\n",
    "\n",
    "# replace the lagrangian with a parameteric model\n",
    "def learned_dynamics(params, nn_forward_fn):\n",
    "    @jit\n",
    "    def dynamics(q, p, conditionals):\n",
    "        state = jnp.concatenate([q, p, conditionals])\n",
    "        return jnp.squeeze(nn_forward_fn(params, state), axis=-1)\n",
    "    return dynamics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ObjectView(object):\n",
    "    def __init__(self, d): self.__dict__ = d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "metadata": {},
   "outputs": [],
   "source": [
    "@jax.jit\n",
    "def qdotdot(q, q_t, conditionals):\n",
    "    g = conditionals\n",
    "    \n",
    "    q_tt = (\n",
    "        g * (1 - q_t**2)**(5./2) / \n",
    "        (1 + 2 * q_t**2)\n",
    "    )\n",
    "    \n",
    "    return q_t, q_tt\n",
    "\n",
    "@jax.jit\n",
    "def ofunc(y, t=None):\n",
    "    q = y[::3]\n",
    "    q_t = y[1::3]\n",
    "    g = y[2::3]\n",
    "    \n",
    "    q_t, q_tt = qdotdot(q, q_t, g)\n",
    "    return jnp.stack([q_t, q_tt, jnp.zeros_like(g)]).T.ravel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray(0.99989885, dtype=float32)"
      ]
     },
     "execution_count": 113,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
     "# Sanity check: for inputs in [-5, 5), tanh(x) * 0.99999 stays strictly below 1.\n",
     "(jnp.tanh(jax.random.uniform(jax.random.PRNGKey(1), (1000,))*10-5)*0.99999).max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([30.,  8.,  5.,  8.,  3.,  2.,  5.,  3.,  5., 31.]),\n",
       " array([-9.9991339e-01, -7.9992497e-01, -5.9993654e-01, -3.9994809e-01,\n",
       "        -1.9995967e-01,  2.8759241e-05,  2.0001718e-01,  4.0000561e-01,\n",
       "         5.9999406e-01,  7.9998249e-01,  9.9997091e-01], dtype=float32),\n",
       " <a list of 10 Patch objects>)"
      ]
     },
     "execution_count": 115,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAPx0lEQVR4nO3dfaxkd13H8feHLhQVtFt6qUth2RYr0sSwJTe1sQkP5amUhJZYdJuAi9YsIBiImLiAiWg0tkZoYiTgYmtXxfJQaLoKiEsfQkigeIul3bIpuy1VS9furaU8xFhp+/WPORfHu3N35t55uPuz71cymTO/8zvnfO9vZj975sw5M6kqJEntecJ6FyBJWhsDXJIaZYBLUqMMcElqlAEuSY3aMMuNnXTSSbVly5ZZblKSmnfLLbc8UFVzy9tnGuBbtmxhYWFhlpuUpOYl+ZdB7R5CkaRGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRs30SkxJWk9bdn563bZ9z6Wvnvg63QOXpEYZ4JLUqKEBnuTJSb6S5GtJ7kjye137qUluTnIgyceSPGn65UqSloyyB/4wcG5VPR/YCpyX5GzgMuDyqjod+DZwyfTKlCQtNzTAq+f73cMndrcCzgWu6dp3AxdOpUJJ0kAjnYWS5DjgFuCngA8AdwEPVdUjXZd7gVNWWHYHsANg8+bNay70/9unx5I0rpE+xKyqR6tqK/BM4CzgeYO6rbDsrqqar6r5ubkjflBCkrRGqzoLpaoeAm4CzgZOSLK0B/9M4L7JliZJOppRzkKZS3JCN/0jwMuA/cCNwEVdt+3AddMqUpJ0pFGOgW8CdnfHwZ8AfLyq/j7J14GPJvkD4J+BK6ZYpyRpmaEBXlW3AWcOaL+b3vFwSdI68EpMSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWrU0ABP8qwkNybZn+SOJG/v2t+b5FtJbu1u50+/XEnSkg0j9HkEeGdVfTXJU4Fbkuzt5l1eVX8yvfIkSSsZGuBVdQg41E1/L8l+4JRpFyZJOrpVHQNPsgU4E7i5a3pbktuSXJlk4wrL7EiykGRhcXFxrGIlSf9r5ABP8hTgk8A7quq7wAeB5wBb6e2hv2/QclW1q6rmq2p+bm5uAiVLkmDEAE/yRHrh/ZGq+hRAVd1fVY9W1WPAh4GzplemJGm5Uc5CCXAFsL+q3t/Xvqmv22uBfZMvT5K0klHOQjkHeANwe5Jbu7Z3Axcn2QoUcA/wpqlUKEkaaJSzUL4IZMCsz0y+HEnSqLwSU5IaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1KihAZ7kWUluTLI/yR1J3t61n5hkb5ID3f3G6ZcrSVoyyh74I8A7q+p5wNnAW5OcAewErq+q04Hru8eSpBkZGuBVdaiqvtpNfw/YD5wCXADs7rrtBi6cVpGSpCOt6hh4ki3AmcDNwMlVdQh6IQ88fYVldiRZSLKwuLg4XrWSpB8aOcCTPAX4JPCOqvruqMtV1a6qmq+q+bm5ubXUKEkaYKQAT/JEeuH9kar6VNd8f5JN3fxNwOHplChJGmSUs1ACXAHsr6r3983aA2zvprcD102+PEnSSjaM0Occ4A3A7Ulu7dreDVwKfDzJJcC/Aq+bTomSpEGGBnhVfRHICrNfOtlyJEmj8kpMSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMM
cElqlAEuSY0ywCWpUQa4JDXKAJekRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSo4YGeJIrkxxOsq+v7b1JvpXk1u52/nTLlCQtN8oe+FXAeQPaL6+qrd3tM5MtS5I0zNAAr6ovAA/OoBZJ0iqMcwz8bUlu6w6xbJxYRZKkkaw1wD8IPAfYChwC3rdSxyQ7kiwkWVhcXFzj5iRJy60pwKvq/qp6tKoeAz4MnHWUvruqar6q5ufm5tZapyRpmTUFeJJNfQ9fC+xbqa8kaTo2DOuQ5GrgxcBJSe4Ffhd4cZKtQAH3AG+aYo2SpAGGBnhVXTyg+Yop1CJJWgWvxJSkRhngktQoA1ySGmWAS1KjDHBJapQBLkmNMsAlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0ywCWpUQa4JDXKAJekRg0N8CRXJjmcZF9f24lJ9iY50N1vnG6ZkqTlRtkDvwo4b1nbTuD6qjoduL57LEmaoaEBXlVfAB5c1nwBsLub3g1cOOG6JElDrPUY+MlVdQigu3/6Sh2T7EiykGRhcXFxjZuTJC039Q8xq2pXVc1X1fzc3Ny0NydJjxtrDfD7k2wC6O4PT64kSdIo1hrge4Dt3fR24LrJlCNJGtUopxFeDXwJeG6Se5NcAlwKvDzJAeDl3WNJ0gxtGNahqi5eYdZLJ1yLJGkVvBJTkhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGDT0PXLBl56fXZbv3XPrqddnuev29sH5/s9Qi98AlqVEGuCQ1ygCXpEYZ4JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMMcElqlAEuSY0a6wcdktwDfA94FHikquYnUZQkabhJ/CLPS6rqgQmsR5K0Ch5CkaRGjbsHXsA/Jingz6tq1/IOSXYAOwA2b9485uYeX9bztyk1O4+331wFX9uTMu4e+DlV9QLgVcBbk7xweYeq2lVV81U1Pzc3N+bmJElLxgrwqrqvuz8MXAucNYmiJEnDrTnAk/xYkqcuTQOvAPZNqjBJ0tGNcwz8ZODaJEvr+duq+oeJVCVJGmrNAV5VdwPPn2AtkqRV8DRCSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaZYBLUqMm8X3g0sQ8Hr+Zb734jYDtcw9ckhplgEtSowxwSWqUAS5JjTLAJalRnoUi4RkZapN74JLUKANckhplgEtSowxwSWqUAS5JjTLAJalRBrgkNcoAl6RGGeCS1CgDXJIaNVaAJzkvyZ1JDibZOamiJEnDrTnAkxwHfAB4FXAGcHGSMyZVmCTp6MbZAz8LOFhVd1fVfwMfBS6YTFmSpGHG+TbCU4B/63t8L/Bzyzsl2QHs6B5+P8mda9zeScADa1x2mqxrdaxrdaxrdY7VushlY9X27EGN4wR4BrTVEQ1Vu4BdY2ynt7Fkoarmx13PpFnX6ljX6ljX6hyrdcF0ahvnEMq9wLP6Hj8TuG+8ciRJoxonwP8JOD3JqUmeBGwD9kymLEnSMGs+hFJVjyR5G/A54Djgyqq6Y2KVHWnswzBTYl2rY12rY12rc6zWBVOoLVVHHLaWJDXAKzElqVEGuCQ16pgK8CSvS3JHkseSrHi6zUqX8HcfqN6c5ECSj3Ufrk6irhOT7O3WuzfJxgF9XpLk1r7bfyW5sJt3VZJv9s3bOqu6un6P9m17T1/7eo7X1iRf6p7v25L8Ut+8iY7XsK98SHJ89/cf7MZjS9+8d3XtdyZ55Th1rKGu30zy9W58rk/y7L55A5/TGdX1xiSLfdv/tb5527vn/UCS7TOu6/K+mr6R5KG+edMcryuTHE6yb4X5SfKnXd23JXlB37zxxquqjpkb8DzgucBNwPwK
fY4D7gJOA54EfA04o5v3cWBbN/0h4C0TquuPgZ3d9E7gsiH9TwQeBH60e3wVcNEUxmukuoDvr9C+buMF/DRwejf9DOAQcMKkx+tor5e+Pr8OfKib3gZ8rJs+o+t/PHBqt57jZljXS/peQ29Zqutoz+mM6noj8GcDlj0RuLu739hNb5xVXcv6/wa9EyumOl7dul8IvADYt8L884HP0rt25mzg5kmN1zG1B15V+6tq2JWaAy/hTxLgXOCart9u4MIJlXZBt75R13sR8Nmq+s8JbX8lq63rh9Z7vKrqG1V1oJu+DzgMzE1o+/1G+cqH/nqvAV7ajc8FwEer6uGq+iZwsFvfTOqqqhv7XkNfpnetxbSN8xUZrwT2VtWDVfVtYC9w3jrVdTFw9YS2fVRV9QV6O2wruQD4q+r5MnBCkk1MYLyOqQAf0aBL+E8BngY8VFWPLGufhJOr6hBAd//0If23ceSL5w+7t0+XJzl+xnU9OclCki8vHdbhGBqvJGfR26u6q695UuO10utlYJ9uPL5Db3xGWXaadfW7hN5e3JJBz+ks6/qF7vm5JsnSBX3HxHh1h5pOBW7oa57WeI1ipdrHHq9xLqVfkySfB35ywKz3VNV1o6xiQFsdpX3sukZdR7eeTcDP0js/fsm7gH+nF1K7gN8Gfn+GdW2uqvuSnAbckOR24LsD+q3XeP01sL2qHuua1zxegzYxoG353zmV19QQI687yeuBeeBFfc1HPKdVddeg5adQ198BV1fVw0neTO/dy7kjLjvNupZsA66pqkf72qY1XqOY2utr5gFeVS8bcxUrXcL/AL23Jhu6vahVXdp/tLqS3J9kU1Ud6gLn8FFW9YvAtVX1g751H+omH07yl8BvzbKu7hAFVXV3kpuAM4FPss7jleTHgU8Dv9O9tVxa95rHa4BRvvJhqc+9STYAP0HvLfE0vy5ipHUneRm9/xRfVFUPL7Wv8JxOIpCG1lVV/9H38MPAZX3LvnjZsjdNoKaR6uqzDXhrf8MUx2sUK9U+9ni1eAhl4CX81ftU4EZ6x58BtgOj7NGPYk+3vlHWe8Sxty7Elo47XwgM/LR6GnUl2bh0CCLJScA5wNfXe7y65+5aescGP7Fs3iTHa5SvfOiv9yLghm589gDb0jtL5VTgdOArY9SyqrqSnAn8OfCaqjrc1z7wOZ1hXZv6Hr4G2N9Nfw54RVffRuAV/N93olOtq6vtufQ+EPxSX9s0x2sUe4Bf7s5GORv4TreTMv54TeuT2bXcgNfS+1/pYeB+4HNd+zOAz/T1Ox/4Br3/Qd/T134avX9gB4FPAMdPqK6nAdcDB7r7E7v2eeAv+vptAb4FPGHZ8jcAt9MLor8BnjKruoCf77b9te7+kmNhvIDXAz8Abu27bZ3GeA16vdA7JPOabvrJ3d9/sBuP0/qWfU+33J3Aqyb8eh9W1+e7fwdL47Nn2HM6o7r+CLij2/6NwM/0Lfur3TgeBH5llnV1j98LXLpsuWmP19X0zqL6Ab38ugR4M/Dmbn7o/fjNXd325/uWHWu8vJRekhrV4iEUSRIGuCQ1ywCXpEYZ4JLUKANckhplgEtSowxwSWrU/wB/9oZOyVWQ9QAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.hist((jnp.tanh(jax.random.normal(jax.random.PRNGKey(1), (100,))*2)*0.99999))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "metadata": {},
   "outputs": [],
   "source": [
    "@partial(jax.jit, static_argnums=(1, 2), backend='cpu')\n",
    "def gen_data(seed, batch, num):\n",
    "    rng = jax.random.PRNGKey(seed)\n",
    "    q0 = jax.random.uniform(rng, (batch,), minval=-10, maxval=10)\n",
    "    qt0 = jax.random.uniform(rng+1, (batch,), minval=-0.99, maxval=0.99)\n",
    "    g = jax.random.normal(rng+2, (batch,))*10\n",
    "\n",
    "    y0 = jnp.stack([q0, qt0, g]).T.ravel()\n",
    "\n",
    "    yt = odeint(ofunc, y0, jnp.linspace(0, 1, num=num), mxsteps=300)\n",
    "\n",
    "    qall = yt[:, ::3]\n",
    "    qtall = yt[:, 1::3]\n",
    "    gall = yt[:, 2::3]\n",
    "    \n",
    "    return jnp.stack([qall, qtall]).reshape(2, -1).T, gall.reshape(1, -1).T, qdotdot(qall, qtall, gall)[1].reshape(1, -1).T\n",
    "\n",
    "@partial(jax.jit, static_argnums=(1,))\n",
    "def gen_data_batch(seed, batch):\n",
    "    rng = jax.random.PRNGKey(seed)\n",
    "    q0 = jax.random.uniform(rng, (batch,), minval=-10, maxval=10)\n",
    "    qt0 = (jnp.tanh(jax.random.normal(jax.random.PRNGKey(1), (batch,))*2)*0.99999)#jax.random.uniform(rng+1, (batch,), minval=-1, maxval=1)\n",
    "    g = jax.random.normal(rng+2, (batch,))*10\n",
    "    \n",
    "    return jnp.stack([q0, qt0]).reshape(2, -1).T, g.reshape(1, -1).T, jnp.stack(qdotdot(q0, qt0, g)).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(DeviceArray([[ 1.4900093 ,  0.9574616 ],\n",
       "              [-8.006279  , -0.95951325],\n",
       "              [-2.1367955 , -0.21696056],\n",
       "              [ 7.883566  ,  0.6245134 ],\n",
       "              [ 1.9313316 ,  0.33272812]], dtype=float32),\n",
       " DeviceArray([[ 5.672981  ],\n",
       "              [13.351437  ],\n",
       "              [ 3.9785006 ],\n",
       "              [17.436966  ],\n",
       "              [-0.24577108]], dtype=float32),\n",
       " DeviceArray([[ 0.9574616 ,  0.00400571],\n",
       "              [-0.95951325,  0.00833027],\n",
       "              [-0.21696056,  3.2232604 ],\n",
       "              [ 0.6245134 ,  2.8466694 ],\n",
       "              [ 0.33272812, -0.15006456]], dtype=float32))"
      ]
     },
     "execution_count": 129,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
     "# Peek at a small batch: state [q, q_t], conditionals [g], target [q_t, q_tt].\n",
     "cstate, cconditionals, ctarget = gen_data_batch(0, 5)\n",
     "cstate, cconditionals, ctarget"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 130,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 131,
   "metadata": {},
   "outputs": [],
   "source": [
    "# qdotdot(jnp.array([0]), jnp.array([0.9]), jnp.array([10]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 132,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 0.29830917716026306 {'act': [4],\n",
    "# 'batch_size': [27.0], 'dt': [0.09609870774790222],\n",
    "# 'hidden_dim': [596.0], 'l2reg': [0.24927677946969878],\n",
    "# 'layers': [4.0], 'lr': [0.005516656601005163],\n",
    "# 'lr2': [1.897157209816416e-05], 'n_updates': [4.0]}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 133,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle as pkl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 134,
   "metadata": {},
   "outputs": [],
   "source": [
    "# loaded = pkl.load(open('./params_for_loss_0.29429444670677185_nupdates=1.pkl', 'rb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 135,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = ObjectView({'dataset_size': 200,\n",
    " 'fps': 10,\n",
    " 'samples': 100,\n",
    " 'num_epochs': 80000,\n",
    " 'seed': 0,\n",
    " 'loss': 'l1',\n",
    " 'act': 'softplus',\n",
    " 'hidden_dim': 500,\n",
    " 'output_dim': 1,\n",
    " 'layers': 4,\n",
    " 'n_updates': 1,\n",
    " 'lr': 0.001,\n",
    " 'lr2': 2e-05,\n",
    " 'dt': 0.1,\n",
    " 'model': 'gln',\n",
    " 'batch_size': 68,\n",
    " 'l2reg': 5.7e-07,\n",
    "})\n",
    "# args = loaded['args']\n",
    "rng = jax.random.PRNGKey(args.seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 142,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax.experimental.ode import odeint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 143,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Track the best parameters/loss seen so far (re-initialized again in a later cell).\n",
     "best_params = None\n",
     "best_loss = np.inf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 144,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import product"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 145,
   "metadata": {},
   "outputs": [],
   "source": [
    "init_random_params, nn_forward_fn = extended_mlp(args)\n",
    "rng = jax.random.PRNGKey(0)\n",
    "_, init_params = init_random_params(rng, (-1, 3))\n",
    "rng += 1\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "The network and its initial parameters are now set up. Next, we train it."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Idea: add identity before inverse:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Let's train it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 146,
   "metadata": {},
   "outputs": [],
   "source": [
    "best_small_loss = np.inf\n",
    "iteration = 0\n",
    "total_epochs = 100\n",
    "minibatch_per = 3000\n",
    "train_losses, test_losses = [], []\n",
    "\n",
    "lr = 1e-3 #1e-3\n",
    "\n",
    "final_div_factor=1e4\n",
    "\n",
    "#OneCycleLR:\n",
    "@jax.jit\n",
    "def OneCycleLR(pct):\n",
    "    #Rush it:\n",
    "    start = 0.3 #0.2\n",
    "    pct = pct * (1-start) + start\n",
    "    high, low = lr, lr/final_div_factor\n",
    "    \n",
    "    scale = 1.0 - (jnp.cos(2 * jnp.pi * pct) + 1)/2\n",
    "    \n",
    "    return low + (high - low)*scale\n",
    "    \n",
    "\n",
    "opt_init, opt_update, get_params = optimizers.adam(\n",
    "    OneCycleLR\n",
    ")\n",
    "opt_state = opt_init(init_params)\n",
    "# opt_state = opt_init(best_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 147,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(0.5, 1.0, 'lr schedule')"
      ]
     },
     "execution_count": 147,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAEICAYAAABcVE8dAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deXxc5X3v8c9vZrSv1mpLtiwvso0BY0BgdkiAxlAcJ21DgWwFGkLuJYWmacO9aZP0Nnlx29w2996GhppA2BIIhDQxhCahScEEMNjG2HgBb9iWLNtarMXat6d/zJGRhWRkazTnaOb7fr300syjmXN+Ohp955nnLI855xARkcQX8rsAERGJDwW+iEiSUOCLiCQJBb6ISJJQ4IuIJAkFvohIklDgy5RkZnvN7KpJXscLZvanMVrWN8zssVg/VuRkKPBFRJKEAl8SjplF/K5BJIgU+DLleUMgPzGzx8ysDfiTUR5zrZltM7OjZnbAzL487GcrzexNM2szs91mtnzYU2eb2cve835tZkXDnneBmb1iZi1mtsnMrhj2szlm9qL3vOeB4c+7wsxqR9Q35hDVidYjcjIU+JIoVgI/AfKBH47y8weAzzvncoAzgN8CmNn5wCPAX3rPvQzYO+x5NwE3AyVAKvBl73nlwC+AbwIFXvvTZlbsPe9HwAaiQf93wGdP5Zcax3pExk2BL4niVefcz5xzg865rlF+3gcsNrNc51yzc+4Nr/1W4EHn3PPecw84594e9rwfOOd2eMt8EljqtX8KeM4595z3vOeB9cC1ZlYBnAf8jXOuxzm3BnjmFH+vMddzisuTJKbAl0RR8wE//0OiIbnPG2q50GufBew+wfMODbvdCWR7t2cDn/CGWVrMrAW4BJgBlAHNzrmOYc/dN87fY6QTrUfkpGjnliSKE1721Tm3DlhpZinAHUR767OIvlHMO4X11QCPOuc+N/IHZjYbmGZmWcNCv2JYjR1A5rDHh4GxhmjGXI/IyVIPXxKemaWa2SfNLM851we0AQPejx8AbjazK80sZGblZrZoHIt9DFhhZh8xs7CZpXs7Y2c65/YRHXb5W2/dlwArhj13B5BuZr/vvQH9NZB2sus5hU0hSU6BL8ni08Be7yie24mOjeOce53oTtnvAK3Ai0SHUU7IOVdDdEfx/wQaiPbE/5L3/qduApYBR4CvE90xPPTcVuC/Ad8HDhDt8R931M5JrEdk3EwToIiIJAf1EkREkoQCX0QkSSjwRUSShAJfRCRJBPo4/KKiIldZWel3GSIiU8qGDRsanXPvO7cj0IFfWVnJ+vXr/S5DRGRKMbNRz+zWkI6ISJJQ4IuIJAkFvohIklDgi4gkCQW+iEiSiFvgm9lpZnafNxXdF+K1XhERiRpX4JvZg2ZWb2ZbRrQvN7N3zGyXmd19omU457Y7524HrgeqT71kERE5FeM9Dv8h4LsMu8SrN2nDvcDVRC/tus7MVgNh4J4Rz7/FOVdvZh8F7vaWJXEwMOjo7O2nq3eAjt6BY7e7+wbpGxykr3+Q/kFH38AgvcNu9w1Ev/cPDDJ0QVWzoe92bPnH2jDMICUcIjUSIi0cIiVipIbDpEaibSlhIy0SIjUcJj0lRGZahOzUCJlpYVLCGl0UmWzjCnzn3BozqxzRfD6wyzm3B8DMngBWOufuAa4bYzmrgdVm9guikzy/j5ndBtwGUFFRMZ7yEppzjqM9/TS199LU3kNLZx+tXe99tXV734e1dfREg72zd4Ce/kG/f4VxSY2EyEoNk5UWISs1QlZa9HZ+Zir5GSnkZ6aQl5Fy3P1oWyoFWamEQ/bBKxFJchM507ac4+cRrSU64cOozOwK4A+Izuzz3FiPc86tAlYBVFdXJ+zF+vsGBjnc1s2h1m4OtnZzuK2bhvYemtp7afS+N7X30NjRS+8JQjsnLUJuRgq5GSnkZUSoLMwiOz0ampmpYTJS
w2SlRshIDZOZGibTa0871uuOfkXCRuqw29F2IxIKEbL35uYb6u07HCOnUhh0jr5+R8/AAL390U8Jvf3RTw69AwP09jt6vU8S3X3RN6X2ngE6e/rp6B2go6efjt5+Onqib1Zt3f3UNnfR0tlLa1cfg2O8GkIGBVmpFGWnUZyTRrH3/dj9nDSm56VTlpdBRmp44n88kSlqIoE/WpdqzIB2zr0AvDCB9U0prZ197D/Syf4jndQ0d3rB3nUs4Bvae94XmClhoyg7jcLsaHgtKM2hKCeVoqw0inJSKchKY5rX083LSCE7LUIkaEMhqQApMV/s4GD0k05rZx8tXb20dPbR0tVHc0f0jbGhvYeGo700tPewp6GDhvaeUd8op2WmMCMvg7L8DMry0ynLz2BGXvR7RUEmJTlpxw1ZiSSSiQR+LdFJoIfMBOomVk6Uma0AVsyfPz8Wi5sUzjnqWrvZ29jB/iOd7GvqpMYL+P1HOmnt6jvu8dlpEWbkpTM9L52F03OYnpdx7H5ZXgbTc9PJzYgobMYQCtmxN7qK9+b/HpNzjrbufhrbe6hv6+FQWxd1Ld3UtXRxsLWb2uZOXn+3ibbu/uOel54SYnZBFrMLM72vLCoLo/fL8jM0dCRT2rinOPTG8J91zp3h3Y8QnYz5SqLzcq4DbnLObY1VcdXV1c7vi6f1DQyyr6mDXfXt7KpvZ3dDh/e9nc7egWOPSwkbM6dlMqsgk4qCaG+xoiCLioJMZhZkkJse+16vTFx7Tz8HW7o40NJFjffGvbepk/1HOtjX1HncPpCUsFFRkElVSQ5VpdnML8mmqiSHucVZpKdoqEiCw8w2OOfedzTkuHr4ZvY4cAVQZGa1wNedcw+Y2R3Ar4gemfNgrMLejx7+UI99e10b2w62sf1gGzsOH2VfUyf9wwaPy/LSmVeSzfXVs5hfks3c4ixmF2YxPTddvb8pKDstQlVpDlWlOe/72eCg4/DRbvY1dbKvqYO9TZ3saWhnR/1Rnt9+mAHvdREymF2Y5b0BZLOgNIfTy3KZW5yt14QESqAnMZ+sHn5v/yC76tvZdrCNbXXRcN92sO24YZg5RVks8Hpx80uymVcc/cpKC/QVpSVOevoH2NvYyc76o+w83H7s+7uNHcc6COkpIRZNz+X0slzOKM/j9LJcFpTm6NOATLqxevgJH/jOOfY2dfJmTTObalp5s6aFbXVt9A5EP6oP/VOeNiOXxWW5LJ6Rw8LpuWQr2OUU9A0Msruhna0H2tha18bWula2HWzjqLevIBIy5pdkc3pZHksr8jmnIp+FpTnB2/kuU9qUCvxhQzqf27lz50k//7U9Tby8q5GNNS1srm091nPPTA1zZnkeS2flc0Z5HqfNyGVOUZY+dsukcs5Rc6SLrXWtbKlrZWtdG1sOtNLY3gtARkqYJTPzOLtiGmdX5HN2RT4lOek+Vy1T2ZQK/CGn2sP/yk8289SGGhZOz2XprDzOmpnP0op8qkpyFO4SCM45apu7eGN/Mxv3t7CxpoVtda30DUT/H2cVZLBsTiHL5hRwwdxCZk7L0BFcMm5JFfgNR3vISoueZCQyVXT3DbC1ro2N+5tZt/cIr797hObO6KfTsrx0ls2NvgEsm1tIZWGm3gBkTFMq8Cc6pCOSCAYHHTvr23nt3SZe23OE195tOjYMVJ6fwaVVRVy2oJiL5xWRl6nDfuU9UyrwhwThOHyRoHDOsbuhg7V7mvjdzkZe3t3I0e5+QgZLZuZz2YJiLqsqYumsfO0ETnIKfJEE0z8wyKbaFtbsaGTNzgY21bQw6CAnPcKHFpZw9eJSLl9YrJP+kpACXyTBtXb28cruRn77dj2/fbuepo5eUsLGBXMLuXpxKVeeVkp5fobfZUocTKnA1xi+yMQMDDo27m/m+W2HeX77YfY0dABwelkuy0+fznVnlTGnKMvnKmWyTKnAH6Ievkhs7G5o5z+2HebX2w6zYV8zAGeU57JiSRnXnVWmnn+CUeCLCAAH
W7v4xeaDPLP5IJtqWgA4d/Y0ViyZwbVLZuikrwSgwBeR99nf1Mkzm+t4ZlMdbx86Ssjgkqpirq+eyVWnleq6P1OUAl9ETmjn4aOs3lTH0xtqqWvtJi8jhZVLy/jEubM4ozxXJ3pNIVMq8LXTVsQ/A4OOV3Y38tT6Wn659RC9/YMsmp7DTcsq+PjZ5eToMM/Am1KBP0Q9fBF/tXb18cymOp5cX8Pm2lYyU8N8/OxyPn3hbBZNz/W7PBmDAl9EJmRTTQuPrd3H6k119PQPcn5lAZ+6cDbLT59OakRn9gaJAl9EYqK5o5enNtTw2Nr97D/SyfTcdG6+uJIbl1XorN6AUOCLSEwNDjpe3NHA/S/t4ZXdTWSnRbhpWQU3X1zJjDwd1+8nBb6ITJotB1pZtWYPv3jrIAZ89Kwybr9iHgtGmStYJt+UCnwdpSMyNdUc6eTBl9/lx+tq6Oob4NozZ3DXlVWjThIvk2dKBf4Q9fBFpqaWzl6+/9K7/ODld+nsG+C6JWXceeV85pco+ONBgS8icdfc0cv9L+3hoVf20tU3wIolZdx5VRXzirP9Li2hKfBFxDdHOnpZtWYPj7y6l57+QW44bxZ3XlWl6/ZMEgW+iPiuqb2Hf/7tLh5bu4/USIjPXzaPz102R/NPx9hYga+zJUQkbgqz0/jGR0/n+S9dzuULivnOf+zg8m+/wE821DI4GNzOZ6JQ4ItI3M0pyuJ7nzqXp79wETOnZfDlpzbxR/e9wpYDrX6XltAU+CLim3NnT+Pp2y/i23+0hP1HOlnx3d/x1z97i5bOXr9LS0gKfBHxVShkfKJ6Fr/5iyv4k4sqefz1Gj78jy/ys40HCPI+xqkokIFvZivMbFVrqz7eiSSLvIwUvr7idH7xZ5cwuzCTu378Jrc8tI66li6/S0sYOkpHRAJnYNDxyKt7+YdfvkPI4CvXLOJTy2YTCmkSlvHQUToiMmWEQ8bNF8/h139+GefMnsbXfr6Vm76/lgPq7U+IAl9EAmtWQSaP3HI+//CHS3irtpXl/3cNqzfV+V3WlKXAF5FAMzOuP28Wz915KVUl2fzZ4xu564mNtHb1+V3alKPAF5EpYXZhFk9+/kL+/KoFPLP5INf+v5fYuL/Z77KmFAW+iEwZkXCIO6+q4qnbL8QMrv/XV3n01b06fHOcFPgiMuWcUzGNZ794CZfML+Jvfr6VP//xm3T29vtdVuAp8EVkSsrPTOWBz57HX1y9gJ9vquNj977MnoZ2v8sKNAW+iExZoZDxxSureOSW82k42sPKe1/mlV2NfpcVWAp8EZnyLq0q5pkvXsKMvHQ+8+DrPLmuxu+SAimugW9mWWa2wcyui+d6RSTxzZyWyU++cBEXzivkr57ezP/+97d1yeURxhX4ZvagmdWb2ZYR7cvN7B0z22Vmd49jUV8BnjyVQkVEPkhuego/+JPzuGlZBfe9uJs7Hn+D7r4Bv8sKjPFOM/MQ8F3gkaEGMwsD9wJXA7XAOjNbDYSBe0Y8/xZgCbAN0JxmIjJpIuEQ3/rYGcwtyuJbz22nsf11vv/ZanLTU/wuzXfjCnzn3BozqxzRfD6wyzm3B8DMngBWOufuAd43ZGNmHwKygMVAl5k955wbnEDtIiKjMjP+9NK5lOSm86Ufv8lN96/l4ZvPpzA7ze/SfDWRMfxyYPiekVqvbVTOua865+4CfgTcP1bYm9ltZrbezNY3NDRMoDwRSXYfPauM+z9Tzc7D7XziX1/lYGtyX3xtIoE/2nVKP3APiXPuIefcsyf4+SrnXLVzrrq4uHgC5YmIwIcWlfDorcuob+vhhlVrOdTa7XdJvplI4NcCs4bdnwnE5DJ2mgBFRGLp/DkFPHzL+TS193Lj/Ws53JacoT+RwF8HVJnZHDNLBW4AVseiKOfcM8652/Ly8mKxOBERzp09jYdvOY/6tm5uXLWW+iQM/fEelvk48Cqw0MxqzexW51w/cAfwK2A78KRzbmssilIPX0Qmw7mzoz39
w23dfPqB12ntTK5LLGuKQxFJOi/vauTmH6xjycw8Hr11GRmpYb9LiilNcSgi4rl4fhHf+eOlbNjfzB0/eoO+geQ4QjyQga8hHRGZbL+/ZAb/a+UZ/Obtev7mZ1uS4pr6gQx87bQVkXj49AWzueND83liXQ0P/O5dv8uZdOO9tIKISEL60tUL2N3Qzree286coiyuPK3U75ImTSB7+CIi8RIKGf94/VmcXpbLnz2+kbcPtfld0qQJZOBrDF9E4ikzNcL3P3MeWWkRbn90A23diXm4ZiADX2P4IhJv0/PSufeT51DT3MVfPbU5IXfiBjLwRUT8cF5lAXcvX8Qvtx7iwZf3+l1OzCnwRUSG+dNL5/B7i0u557ntbNjX7Hc5MRXIwNcYvoj4xcz49ifOoiw/gzuf2Eh7T7/fJcVMIANfY/gi4qe8jBT+6fqzqGvp4pvPbvO7nJgJZOCLiPiturKAz18+jyfW1fCb7Yf9LicmFPgiImO466oqFk3P4StPv8WRjl6/y5mwQAa+xvBFJAjSImG+88dLae3q5Ws/3+J3ORMWyMDXGL6IBMVpM3L54oereHbzQV7cMbXn2Q5k4IuIBMnnL5/L3KIsvvbzLXT3DfhdzilT4IuIfIC0SJhvfuwM9jV18i8v7Pa7nFOmwBcRGYeL5hexcmkZ972wmz0N7X6Xc0oU+CIi4/TV3z+NtJQQX18dk+m7406BLyIyTiU56dx5ZRUv7WzkpZ1TbwduIANfh2WKSFB9+sLZlOdn8Pe/fJvBwal1Rc1ABr4OyxSRoEqLhPnyRxaw5UAbz2yu87uckxLIwBcRCbKVZ5Vz2oxc/s+v36G3f9DvcsZNgS8icpJCIePuaxZRc6SLH762z+9yxk2BLyJyCi6rKuKieYXc+5+7pszJWAp8EZFTYGZ88cNVNLb38pMNtX6XMy4KfBGRU3TB3AKWzspn1Zo99A8EfyxfgS8icorMjNsvn8f+I538+5ZDfpfzgQIZ+DoOX0Smit9bXMrc4iy+98JunAv2cfmBDHwdhy8iU0UoFO3lbzvYxpqdjX6Xc0KBDHwRkankY0vLmZ6bzvde2OV3KSekwBcRmaDUSIibL65k7Z4jvHPoqN/ljEmBLyISA9dXzyI1EuKxtcE9EUuBLyISA9OyUlmxpIyfvlFLe0+/3+WMSoEvIhIjn75wNh29A/xs4wG/SxmVAl9EJEbOmpnHouk5PLW+xu9SRqXAFxGJETPjj86dyabaVnYcDt7OWwW+iEgMffzsciIhC+T1dRT4IiIxVJidxocXlfDTNw4E7vo6CnwRkRj7+NnlNLb38Nq7R/wu5ThxC3wzu8LMXjKz+8zsinitV0Qk3j60qISs1DDPbArWFIjjCnwze9DM6s1sy4j25Wb2jpntMrO7P2AxDmgH0oHgDW6JiMRIekqYqxeX8suthwI1BeJ4e/gPAcuHN5hZGLgXuAZYDNxoZovN7Ewze3bEVwnwknPuGuArwN/G7lcQEQme65aU0dLZx+92NfhdyjGR8TzIObfGzCpHNJ8P7HLO7QEwsyeAlc65e4DrTrC4ZiBtrB+a2W3AbQAVFRXjKU9EJHAuXVBETlqEX245xIcXlfpdDjCxMfxyYPjZBbVe26jM7A/M7F+BR4HvjvU459wq51y1c666uLh4AuWJiPgnLRLm8oXF/GZ7PQODwbhO/kQC30ZpG/O3cs791Dn3eefcHzvnXjjhgjUBiogkgKsXl9LU0cubNc1+lwJMLPBrgVnD7s8EYrJLWhOgiEgiuGJhCZGQ8fy2er9LASYW+OuAKjObY2apwA3A6tiUJSIy9eVlpLBsbgHPbwvGfLfjPSzzceBVYKGZ1ZrZrc65fuAO4FfAduBJ59zWWBSlIR0RSRRXLipld0MHNUc6/S5lfIHvnLvROTfDOZfinJvpnHvAa3/OObfAOTfPOfetWBWlIR0RSRSXLSgC4KUAzHerSyuIiEyiecXZzMhLD8Tx+IEMfA3piEiiMDMu
mV/Ey7uafD88M5CBryEdEUkkly4oprWrj7cO+NuJDWTgi4gkkovnFWIGv9vp77BOIANfQzoikkgKs9NYWJrj++WSAxn4GtIRkURzXmUBb+xr9nVSlEAGvohIoqmunEZH7wBvH/JvrlsFvohIHJxXWQDAur3+DesEMvA1hi8iiaYsP4Py/AzW7/XvQmqBDHyN4YtIIjqvchqv7z2Cc/4cjx/IwBcRSUTnVhbQcLSH2uYuX9avwBcRiZMl5dFRiy0+nYClwBcRiZNFM3JICRubFfjv0U5bEUlEaZEwC0pz1MMfTjttRSRRLZmZx+baVl923AYy8EVEEtUZ5Xm0dvX5suNWgS8iEkdnejtuN9fGf1hHgS8iEkcLpw/tuG2J+7oV+CIicZQWCTOvOJsdPlxTJ5CBr6N0RCSRLSjNYcfh9rivN5CBr6N0RCSRVZVkc6Cli46e/riuN5CBLyKSyKpKcwDYVR/fXr4CX0QkzhaUZgOw43B8x/EV+CIicVZRkElqOKQevohIoouEQ8wtzlIPX0QkGfhxpI4CX0TEB34cqaPAFxHxwbyS6I7bvU0dcVtnIANfJ16JSKKrKMgEoOZIZ9zWGcjA14lXIpLoZnmBvz/ZA19EJNHlZaSQl5GiwBcRSQYVBZnsPxK/6+Ir8EVEfFJRkEmtevgiIolvVkEmtc1dDAzGZ7pDBb6IiE8qCjLpHRjkcFt3XNanwBcR8UlFnI/UUeCLiPhEgS8ikiRm5KcTDlncTr5S4IuI+CQlHGJGXnrceviRuKwFMLMQ8HdALrDeOfdwvNYtIhJUM/LSg7XT1sweNLN6M9syon25mb1jZrvM7O4PWMxKoBzoA2pPrVwRkcRSkptO/dGeuKxrvEM6DwHLhzeYWRi4F7gGWAzcaGaLzexMM3t2xFcJsBB41Tn3JeALsfsVRESmrpKcNOrb4hP44xrScc6tMbPKEc3nA7ucc3sAzOwJYKVz7h7gupHLMLNaoNe7OzDWuszsNuA2gIqKivGUJyIyZZXmptPe009HTz9ZaZM7yj6RnbblQM2w+7Ve21h+CnzEzP4ZWDPWg5xzq5xz1c656uLi4gmUJyISfKW5aQBxGdaZyNuJjdI25vnBzrlO4NYJrE9EJOGU5KQDcLitmzlFWZO6ron08GuBWcPuzwTqJlZOlCZAEZFkMdTDj8eROhMJ/HVAlZnNMbNU4AZgdSyK0gQoIpIsir0efkMchnTGe1jm48CrwEIzqzWzW51z/cAdwK+A7cCTzrmtsShKPXwRSRa56RHSU0Jx6eGP9yidG8dofw54LqYVRZf7DPBMdXX152K9bBGRIDEzSuN0LL4urSAi4rOSnLTAj+FPGg3piEgyKclNj8vJV4EMfO20FZFkUpqjIR0RkaRQkptGe08/7T39k7qeQAa+hnREJJkcO9t2ksfxAxn4GtIRkWRSmBUN/CMdvR/wyIkJZOCLiCST/MwUAFo6+yZ1PQp8ERGfTctMBaC5Mwl7+BrDF5FkktQ9fI3hi0gyyU6LEAlZcvbwRUSSiZmRn5lCS1cS9vBFRJJNfmYqLerhi4gkvvyMFJo7krCHr522IpJs8jNTk3MMXzttRSTZTMtMoVVj+CIiiW9aVpL28EVEkk1eRgrdfYN09w1M2joU+CIiARCPs20V+CIiATAtDmfbBjLwdZSOiCSbPC/wk66Hr6N0RCTZDA3pJF0PX0Qk2WgMX0QkScTjipkKfBGRAEhPCZOeEprU6+ko8EVEAmJaZirN6uGLiCS+yb5ipgJfRCQgCrNSaWxX4IuIJLyi7FSaOnombfmBDHydeCUiyagoO43Go0nWw9eJVyKSjIpy0ujqG6Cjp39Slh/IwBcRSUaFWdGTrxrbJ2dYR4EvIhIQRTlpgAJfRCThFWdHA79hksbxFfgiIgFR5AX+ZB2po8AXEQmIwmxvDF89fBGRxJYSDpGfmaIxfBGRZFCUnabAFxFJBkXZqQp8EZFkUJidNmnX01Hgi4gESPEk
DulEJmWpozCzS4FPeutc7Jy7KF7rFhGZKoqyUzna3U933wDpKeGYLntcPXwze9DM6s1sy4j25Wb2jpntMrO7T7QM59xLzrnbgWeBh0+9ZBGRxPXesfixH9YZ75DOQ8Dy4Q1mFgbuBa4BFgM3mtliMzvTzJ4d8VUy7Kk3AY/HoHYRkYQzFPiNR2M/rDOuIR3n3BozqxzRfD6wyzm3B8DMngBWOufuAa4bbTlmVgG0OufaxlqXmd0G3AZQUVExnvJERBJGZVEm1545PebDOTCxnbblQM2w+7Ve24ncCvzgRA9wzq1yzlU756qLi4snUJ6IyNQzvySHf/nkuSycnhPzZU9kp62N0uZO9ATn3NfHtWCzFcCK+fPnn0pdIiIyion08GuBWcPuzwTqJlZOlCZAERGJvYkE/jqgyszmmFkqcAOwOjZliYhIrI33sMzHgVeBhWZWa2a3Ouf6gTuAXwHbgSedc1tjUZTmtBURiT1z7oTD7r6qrq5269ev97sMEZEpxcw2OOeqR7br0goiIkkikIGvIR0RkdgLZODrKB0RkdgL9Bi+mTUA+07x6UVAYwzLiZWg1gXBrU11nZyg1gXBrS3R6prtnHvfmauBDvyJMLP1o+208FtQ64Lg1qa6Tk5Q64Lg1pYsdQVySEdERGJPgS8ikiQSOfBX+V3AGIJaFwS3NtV1coJaFwS3tqSoK2HH8EVE5HiJ3MMXEZFhFPgiIkkiIQP/ZObaneQ6ZpnZf5rZdjPbamZ3eu3fMLMDZvam93WtD7XtNbO3vPWv99oKzOx5M9vpfZ8W55oWDtsmb5pZm5nd5df2Gm0u57G2kUX9f+81t9nMzolzXd82s7e9df+bmeV77ZVm1jVs290X57rG/NuZ2f/wttc7ZvaRONf142E17TWzN732eG6vsfJh8l5jzrmE+gLCwG5gLpAKbAIW+1TLDOAc73YOsIPo/L/fAL7s83baCxSNaPsH4G7v9t3A3/v8dzwEzPZrewGXAecAWz5oGwHXAv9OdGKgC4DX4lzX7wER7/bfD6urcvjjfNheo/7tvP+DTUAaMMf7nw3Hq64RP/9H4Gs+bK+x8mHSXmOJ2MM/Nteuc64XeAJY6UchzrmDzrk3vNtHiV5G+oOmgfTTSqokAsMAAAMDSURBVOBh7/bDwMd8rOVKYLdz7lTPtJ4w59wa4MiI5rG20UrgERe1Fsg3sxnxqss592sXvWQ5wFqiExLF1RjbaywrgSeccz3OuXeBXUT/d+Nal5kZcD3w+GSs+0ROkA+T9hpLxMA/lbl2J51FJ4E/G3jNa7rD+1j2YLyHTjwO+LWZbbDoxPEApc65gxB9MQIlPtQ15AaO/yf0e3sNGWsbBel1dwvRnuCQOWa20cxeNLNLfahntL9dULbXpcBh59zOYW1x314j8mHSXmOJGPgnPdfuZDOzbOBp4C7nXBvwPWAesBQ4SPQjZbxd7Jw7B7gG+O9mdpkPNYzKojOofRR4ymsKwvb6IIF43ZnZV4F+4Ide00Ggwjl3NvAl4EdmlhvHksb62wViewE3cnzHIu7ba5R8GPOho7Sd1DZLxMCftLl2T4WZpRD9Y/7QOfdTAOfcYefcgHNuELifSfooeyLOuTrvez3wb14Nh4c+Inrf6+Ndl+ca4A3n3GGvRt+31zBjbSPfX3dm9lngOuCTzhv09YZMmrzbG4iOlS+IV00n+NsFYXtFgD8AfjzUFu/tNVo+MImvsUQM/MDMteuNDz4AbHfO/dOw9uHjbh8Htox87iTXlWVmOUO3ie7w20J0O33We9hngZ/Hs65hjut1+b29RhhrG60GPuMdSXEB0Dr0sTwezGw58BXgo865zmHtxWYW9m7PBaqAPXGsa6y/3WrgBjNLM7M5Xl2vx6suz1XA28652qGGeG6vsfKByXyNxWNvdLy/iO7N3kH03fmrPtZxCdGPXJuBN72va4FHgbe89tXAjDjXNZfoERKbgK1D2wgoBH4D7PS+F/iwzTKBJiBv
WJsv24vom85BoI9o7+rWsbYR0Y/b93qvubeA6jjXtYvo+O7Q6+w+77F/6P2NNwFvACviXNeYfzvgq972ege4Jp51ee0PAbePeGw8t9dY+TBprzFdWkFEJEkk4pCOiIiMQoEvIpIkFPgiIklCgS8ikiQU+CIiSUKBLyKSJBT4IiJJ4r8ArUnT0Hilux0AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(OneCycleLR(jnp.linspace(0, 1, num=200)))\n",
    "plt.yscale('log')\n",
    "plt.title('lr schedule')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 152,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax.tree_util import tree_flatten\n",
    "\n",
    "@jax.jit\n",
    "def loss(params, cstate, cconditionals, ctarget):\n",
    "    runner = jax.vmap(\n",
    "        partial(\n",
    "            hamiltonian_eom,\n",
    "            learned_dynamics(params, nn_forward_fn)), (0, 0), 0)\n",
    "    preds = runner(cstate, cconditionals)[:, [0, 1]]\n",
    "    \n",
    "    error = jnp.abs(preds - ctarget)\n",
    "    #Weight additionally by proximity to c!\n",
    "    error_weights = (1 + 1/jnp.sqrt(1.0-cstate[:, [1]]**2))\n",
    "    \n",
    "    return jnp.sum(error * error_weights)*len(preds)/jnp.sum(error_weights)\n",
    "\n",
    "@jax.jit\n",
    "def update_derivative(i, opt_state, cstate, cconditionals, ctarget):\n",
    "    params = get_params(opt_state)\n",
    "    param_update = jax.grad(\n",
    "            lambda *args: loss(*args)/len(cstate),\n",
    "            0\n",
    "        )(params, cstate, cconditionals, ctarget)\n",
    "    params = get_params(opt_state)\n",
    "    return opt_update(i, param_update, opt_state), params\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 153,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Epoch counter; also used as the data seed for gen_data_batch in the next cell.\n",
     "epoch = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 154,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Batch of 128 samples (seed = epoch) used by the loss/gradient smoke tests below.\n",
     "cstate, cconditionals, ctarget = gen_data_batch(epoch, 128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 155,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray(175.4566, dtype=float32)"
      ]
     },
     "execution_count": 155,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
     "# Loss of the freshly initialized parameters on that batch.\n",
     "loss(get_params(opt_state), cstate, cconditionals, ctarget)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 156,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Smoke-test (and compile) one optimizer step; the result is discarded.\n",
     "update_derivative(0, opt_state, cstate, cconditionals, ctarget);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 157,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Reset the PRNG key (it was advanced with `rng += 1` after parameter init).\n",
     "rng = jax.random.PRNGKey(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 158,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Reset the epoch counter.\n",
     "epoch = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 159,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 160,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(128, 2)"
      ]
     },
     "execution_count": 160,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of a generated state batch: (number of samples, state width).\n",
    "gen_data_batch(0, 128)[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 161,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray([[ -7.3825674 ],\n",
       "             [  0.48637673],\n",
       "             [  5.567146  ],\n",
       "             [  3.391427  ],\n",
       "             [-11.220732  ]], dtype=float32)"
      ]
     },
     "execution_count": 161,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First five conditioning values (one column per sample).\n",
    "cconditionals[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 162,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray([[-8.077726  ,  0.13323684],\n",
       "             [-2.589147  , -0.12181558],\n",
       "             [-8.912823  ,  0.9635672 ],\n",
       "             [ 4.5006466 ,  0.9999864 ],\n",
       "             [ 6.4032707 ,  0.98424286]], dtype=float32)"
      ]
     },
     "execution_count": 162,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First five states; columns are presumably (q, p), matching the split in hamiltonian_eom.\n",
    "cstate[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 163,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray([[ 1.3323684e-01, -6.8172374e+00],\n",
       "             [-1.2181558e-01,  4.5502922e-01],\n",
       "             [ 9.6356720e-01,  2.6673502e-03],\n",
       "             [ 9.9998641e-01,  4.3539304e-12],\n",
       "             [ 9.8424286e-01, -6.6027965e-04]], dtype=float32)"
      ]
     },
     "execution_count": 163,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First five targets; column 0 equals cstate[:, 1], consistent with (q_t, p_t) and q_t = p -- presumably.\n",
    "ctarget[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 164,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Best epoch loss seen so far and a host copy of the matching parameters.\n",
    "best_loss = np.inf\n",
    "best_params = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 165,
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy as copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 166,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "00151e8540f84c909ceb40a433194596",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=0 lr=0.0006752928602509201 loss=0.4576779007911682\n",
      "epoch=1 lr=0.0006957106525078416 loss=0.4278222322463989\n",
      "epoch=2 lr=0.0007157498621381819 loss=0.42640888690948486\n",
      "epoch=3 lr=0.0007353719556704164 loss=0.4248338043689728\n",
      "epoch=4 lr=0.0007545390981249511 loss=0.4242565929889679\n",
      "epoch=5 lr=0.000773213803768158 loss=0.4227651357650757\n",
      "epoch=6 lr=0.0007913602748885751 loss=0.42145994305610657\n",
      "epoch=7 lr=0.0008089432376436889 loss=0.42091017961502075\n",
      "epoch=8 lr=0.0008259288733825088 loss=0.4207681119441986\n",
      "epoch=9 lr=0.0008422840619459748 loss=0.41996580362319946\n",
      "epoch=10 lr=0.0008579773711971939 loss=0.4197915494441986\n",
      "epoch=11 lr=0.000872978416737169 loss=0.41876354813575745\n",
      "epoch=12 lr=0.0008872582111507654 loss=0.41935107111930847\n",
      "epoch=13 lr=0.0009007891057990491 loss=0.41843855381011963\n",
      "epoch=14 lr=0.0009135449072346091 loss=0.4186243414878845\n",
      "epoch=15 lr=0.0009255009354092181 loss=0.4181230068206787\n",
      "epoch=16 lr=0.0009366340818814933 loss=0.41830745339393616\n",
      "epoch=17 lr=0.000946922751609236 loss=0.41750913858413696\n",
      "epoch=18 lr=0.0009563472121953964 loss=0.4169684648513794\n",
      "epoch=19 lr=0.0009648891282267869 loss=0.4164976477622986\n",
      "epoch=20 lr=0.0009725319687277079 loss=0.4160488247871399\n",
      "epoch=21 lr=0.0009792608907446265 loss=0.4161275327205658\n",
      "epoch=22 lr=0.000985063030384481 loss=0.4145805537700653\n",
      "epoch=23 lr=0.0009899272117763758 loss=0.4153549373149872\n",
      "epoch=24 lr=0.0009938437724485993 loss=0.41514167189598083\n",
      "epoch=25 lr=0.000996805145405233 loss=0.4160304367542267\n",
      "epoch=26 lr=0.0009988059755414724 loss=0.4155493378639221\n",
      "epoch=27 lr=0.0009998419554904103 loss=0.4141225814819336\n",
      "epoch=28 lr=0.0009999113390222192 loss=0.41487812995910645\n",
      "epoch=29 lr=0.0009990138933062553 loss=0.4153560698032379\n",
      "epoch=30 lr=0.000997151480987668 loss=0.4150068461894989\n",
      "epoch=31 lr=0.000994327594526112 loss=0.4148404002189636\n",
      "epoch=32 lr=0.0009905475890263915 loss=0.41430777311325073\n",
      "epoch=33 lr=0.000985819031484425 loss=0.41456151008605957\n",
      "epoch=34 lr=0.0009801508858799934 loss=0.41460469365119934\n",
      "epoch=35 lr=0.0009735542116686702 loss=0.41374877095222473\n",
      "epoch=36 lr=0.0009660416981205344 loss=0.41455087065696716\n",
      "epoch=37 lr=0.0009576278389431536 loss=0.4142729341983795\n",
      "epoch=38 lr=0.0009483289322815835 loss=0.4137929379940033\n",
      "epoch=39 lr=0.0009381631389260292 loss=0.4138018786907196\n",
      "epoch=40 lr=0.0009271499002352357 loss=0.41362735629081726\n",
      "epoch=41 lr=0.0009153104037977755 loss=0.4134262204170227\n",
      "epoch=42 lr=0.000902667990885675 loss=0.41215479373931885\n",
      "epoch=43 lr=0.000889246875885874 loss=0.41274869441986084\n",
      "epoch=44 lr=0.0008750728447921574 loss=0.4129306972026825\n",
      "epoch=45 lr=0.0008601734298281372 loss=0.41323956847190857\n",
      "epoch=46 lr=0.00084457773482427 loss=0.41251224279403687\n",
      "epoch=47 lr=0.0008283156785182655 loss=0.4119430184364319\n",
      "epoch=48 lr=0.0008114183438010514 loss=0.41314443945884705\n",
      "epoch=49 lr=0.0007939189672470093 loss=0.4127974808216095\n",
      "epoch=50 lr=0.0007758514257147908 loss=0.41186776757240295\n",
      "epoch=51 lr=0.0007572502945549786 loss=0.4118669033050537\n",
      "epoch=52 lr=0.000738151662517339 loss=0.4118064045906067\n",
      "epoch=53 lr=0.0007185924332588911 loss=0.41286376118659973\n",
      "epoch=54 lr=0.0006986107910051942 loss=0.4120926558971405\n",
      "epoch=55 lr=0.0006782450363971293 loss=0.4117642343044281\n",
      "epoch=56 lr=0.0006575344013981521 loss=0.41221532225608826\n",
      "epoch=57 lr=0.0006365194567479193 loss=0.4109550714492798\n",
      "epoch=58 lr=0.0006152403657324612 loss=0.4106297194957733\n",
      "epoch=59 lr=0.0005937386304140091 loss=0.41195157170295715\n",
      "epoch=60 lr=0.0005720554618164897 loss=0.4116988480091095\n",
      "epoch=61 lr=0.0005502330604940653 loss=0.4107952117919922\n",
      "epoch=62 lr=0.0005283138016238809 loss=0.41156187653541565\n",
      "epoch=63 lr=0.000506339652929455 loss=0.4114292860031128\n",
      "epoch=64 lr=0.00048435357166454196 loss=0.4120069146156311\n",
      "epoch=65 lr=0.0004623976710718125 loss=0.41039448976516724\n",
      "epoch=66 lr=0.00044051450095139444 loss=0.4108618199825287\n",
      "epoch=67 lr=0.0004187468148302287 loss=0.41074541211128235\n",
      "epoch=68 lr=0.0003971360856667161 loss=0.4097667932510376\n",
      "epoch=69 lr=0.00037572436849586666 loss=0.4103347659111023\n",
      "epoch=70 lr=0.00035455342731438577 loss=0.4110722243785858\n",
      "epoch=71 lr=0.0003336636000312865 loss=0.4105161130428314\n",
      "epoch=72 lr=0.00031309586483985186 loss=0.4096032381057739\n",
      "epoch=73 lr=0.0002928894537035376 loss=0.41041135787963867\n",
      "epoch=74 lr=0.0002730839769355953 loss=0.4102528989315033\n",
      "epoch=75 lr=0.0002537174441386014 loss=0.4103509187698364\n",
      "epoch=76 lr=0.00023482716642320156 loss=0.41093015670776367\n",
      "epoch=77 lr=0.00021644984371960163 loss=0.4100760817527771\n",
      "epoch=78 lr=0.0001986212591873482 loss=0.4101187288761139\n",
      "epoch=79 lr=0.00018137549341190606 loss=0.409525066614151\n",
      "epoch=80 lr=0.0001647462631808594 loss=0.4091007709503174\n",
      "epoch=81 lr=0.0001487653498770669 loss=0.4087952971458435\n",
      "epoch=82 lr=0.00013346386549528688 loss=0.4100150167942047\n",
      "epoch=83 lr=0.00011887160508194938 loss=0.4099412262439728\n",
      "epoch=84 lr=0.00010501645738258958 loss=0.4089011251926422\n",
      "epoch=85 lr=9.192524885293096e-05 loss=0.41001367568969727\n",
      "epoch=86 lr=7.96236636233516e-05 loss=0.4094568192958832\n",
      "epoch=87 lr=6.813512300141156e-05 loss=0.40959182381629944\n",
      "epoch=88 lr=5.748197509092279e-05 loss=0.4094073176383972\n",
      "epoch=89 lr=4.768490543938242e-05 loss=0.40952086448669434\n",
      "epoch=90 lr=3.876250411849469e-05 loss=0.4091009199619293\n",
      "epoch=91 lr=3.073247353313491e-05 loss=0.40920308232307434\n",
      "epoch=92 lr=2.3610193238710053e-05 loss=0.40987756848335266\n",
      "epoch=93 lr=1.740936750138644e-05 loss=0.409124493598938\n",
      "epoch=94 lr=1.2142033483542036e-05 loss=0.4094020128250122\n",
      "epoch=95 lr=7.818327503628097e-06 loss=0.40913116931915283\n",
      "epoch=96 lr=4.446651473699603e-06 loss=0.40973106026649475\n",
      "epoch=97 lr=2.0336199213488726e-06 loss=0.4090292155742645\n",
      "epoch=98 lr=5.838221568410518e-07 loss=0.40985363721847534\n",
      "epoch=99 lr=1.0000000116860974e-07 loss=0.4095536172389984\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Main training loop. Iterates range(epoch, total_epochs), so re-running this cell\n",
    "# resumes from the current `epoch` (after a completed run it repeats the last epoch).\n",
    "batch = 512  # minibatch size; loop-invariant, so set once\n",
    "for epoch in tqdm(range(epoch, total_epochs)):\n",
    "    epoch_loss = 0.0\n",
    "    num_samples = 0\n",
    "    # Generate this epoch's data once, then slice it into minibatch_per minibatches.\n",
    "    ocstate, occonditionals, octarget = gen_data_batch(epoch, minibatch_per*batch)\n",
    "    for minibatch in range(minibatch_per):\n",
    "        # Overall training progress in [0, 1); passed to the optimizer as the step and\n",
    "        # also used for the printed lr via OneCycleLR.\n",
    "        fraction = (epoch + minibatch/minibatch_per)/total_epochs\n",
    "        s = np.s_[minibatch*batch:(minibatch+1)*batch]\n",
    "        cstate, cconditionals, ctarget = ocstate[s], occonditionals[s], octarget[s]\n",
    "        opt_state, params = update_derivative(fraction, opt_state, cstate, cconditionals, ctarget)\n",
    "        # NOTE(review): rng is not consumed in this loop, and JAX keys are normally advanced\n",
    "        # with jax.random.split rather than integer addition. Kept so the global value is unchanged.\n",
    "        rng += 10\n",
    "        # params are the pre-update parameters returned by update_derivative.\n",
    "        cur_loss = loss(params, cstate, cconditionals, ctarget)\n",
    "        epoch_loss += cur_loss\n",
    "        num_samples += len(cstate)\n",
    "    # Accumulated loss divided by the number of samples seen this epoch.\n",
    "    closs = epoch_loss/num_samples\n",
    "    print('epoch={} lr={} loss={}'.format(\n",
    "        epoch, OneCycleLR(fraction), closs))\n",
    "    if closs < best_loss:\n",
    "        best_loss = closs\n",
    "        # Copy parameters to host arrays so the snapshot is independent of later updates.\n",
    "        best_params = [[copy(jax.device_get(l2)) for l2 in l1] if len(l1) > 0 else () for l1 in params]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 167,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle as pkl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 169,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled: would pickle best_params plus a description to best_sr_params_hamiltonian.pkl. Uncomment to save.\n",
    "# pkl.dump({'params': best_params, 'description': 'q and g are divided by 10. hidden=500. act=Softplus'},\n",
    "#          open('best_sr_params_hamiltonian.pkl', 'wb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 170,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a fresh optimizer state from the best parameters found during training.\n",
    "opt_state = opt_init(best_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 171,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation data with 50 samples; presumably a single trajectory of 50 steps -- confirm gen_data's arguments.\n",
    "cstate, cconditionals, ctarget = gen_data(0, 1, 50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 172,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(50, 2)"
      ]
     },
     "execution_count": 172,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Expect 50 rows, one per generated sample.\n",
    "cstate.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 173,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x2aac607dae50>]"
      ]
     },
     "execution_count": 173,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAeAElEQVR4nO3deXRc5Znn8e+jfV8syZYteQUDXgADGrOFhCUQQ9KQhUwgdBISgtMnIUtP0n0gJ02nSeck3WdOkzkZMjME0qEzaQhNEuJm6GYJ0Ek6BixjMFhe8IJtLZZkLaWtqrQ980eVTSFkq2xLLtet3+ecOlX31nXpuXbp5/e8933fa+6OiIikv6xUFyAiItNDgS4iEhAKdBGRgFCgi4gEhAJdRCQgclL1g6urq33RokWp+vEiImlp48aNB929ZrL3UhboixYtorGxMVU/XkQkLZnZ3iO9py4XEZGAUKCLiASEAl1EJCAU6CIiAaFAFxEJCAW6iEhAKNBFRAIiZePQRUSCxN2Jjo4zEB1lIDLKQHSUwWjsOfZ67PD2lWfN5tz5FdNegwJdRDJedHSM/sho/DHCQGSUvkOvo7H9ic8DkZF37BscjoX46Hhy95eoKc1XoIuITOTuREbG6YuM0BceoS8yQig8Ql94NGFfLJwT9/XHQ7svMsLw6PiUPyc/J4vSglxKC3IoyY895s8qojQ/h5L4vuL8nMPvF+fnUBp/Ls4/tC+b4rwcsrJsRv4uFOgickoYGRund2iEUHg4/jzy9nM4FsK9Q8OEwiMJj1H6wiMMjx09kAtysygryKWsMBbI5UV5sTAuyKWsIBbCh94rzc+lJL6vND+2rzg/h7ycU/+SowJdRKaVuzM4PEbP4DA9Q8N0D8YCOvY8TM/QCD1DsX294WF6BmPhPBAdPeJnmkFpfg7lRbmUF8YeteUFlBfmUV6YS1lhTuy5IDe+HQvqQyGdn5N9Ev8GUkeBLiJHNT7u9IZH6B6McnBgmK6BYboHo3QNDtMzOEzXYCy0u+MB3jN45BazGZQX5lJZlEdFUS6zSws4Y3YpFfHtinhgVxTlUVH49nZpQS7ZM9RNESQKdJEMNDw6zsGBKJ39UboGoxzsH6ZzIErXwDAHB6IcHIjSPTjMwYFYSI8d4WJfWUEOVSX5zCrOo76yiHPrK6gszmNWcSy0ZxXnUVGUR2VRbLusUME8kxToIgHh7vSFR+noj9DRH6WjP0Jnf5SOvigd/dHDAd45EKV3aGTSzyjJz6G6JI+qknwWzCrivAWVse3iPGaV5FNVnEdVSSyoK4vyyM0+9fuVM4kCXSQNREfHaA9FaQuFOdAXob0vQntflPa+CB190cP7opOM1ijMzaamNJ/ZpfmcPruEi0+roqYkn+rSfKpL8qkpjQV1TWk+BbmZ0dccVAp0kRQbHh3nQChCayhMWyhMa2+EtlCYtt4IbaFYUHcNDr/rzxXmZlNbXsDs0nxWza84/DoW3gXMLouFeEl+Dmbq5sgECnSRGTYYHaW5J0xzzxAtvWFaesI094Zpjb/uHIjiE7qoywtzmVtewLyKQlYtqKC2rIDa8gLmlhdQW1bAnPICShXUMoECXeQEjYyN09ITZl/3EHu7h2juHmJ/zxDNPWH2dw/RM6G/Oi87i3kVBdRVFvK+M2qYV1FIXUUhcysKmFteyLyKAory9Kspxy6pb42ZrQH+B5ANPODu35/w/kLgJ0AN0A38qbs3T3OtIikTGRljX/cQew4OsrdrkD0Hh9jXPci+7iFaeyPvGAWSl51FfWUhdZWFrDx7LvWVhcyvLDq8r7o4f8ZmCkpmmzLQzSwbuA+4GmgGNpjZOndvSjjsvwP/5O4PmdmVwPeAT81EwSIzZWzcaekJs+vgALs7B9ndGXve2zVIayjyjmMri3JZWFXM+Qsq+fCqIhbMij0WVhUzu1SBLamRTAt9NbDT3XcDmNkjwA1A
YqAvB/48/vp54PHpLFJkOkVGxtjdOcibHf3s6hjgzY4BdnUO8NbBoXdMiCkvzGVxdTEXLaliYVUxi6qLWFxdzMKqYsoLc1N4BiKTSybQ64D9CdvNwIUTjnkN+BixbpmPAKVmVuXuXYkHmdlaYC3AggULjrdmkaSMjI2z5+Ag2w70s+NAP9sO9PNmRz/7uocOX4TMMlhYVcxpNSVccdZsllQXs6SmhCXVxcwqztNFR0kryQT6ZN/oidPGvgH8TzO7Ffgd0AK8a2EGd78fuB+goaEhuXUmRZJwcCBKU2sfW9v6aGrrY/uBfnZ1DjAyFvuaZWcZS6qLWTmvnA+vqmPpnBJOn13C4urijFnnQ4IvmUBvBuYnbNcDrYkHuHsr8FEAMysBPubuoekqUuQQd6e5J8zrLSFebwkdDvGO/ujhY+aVF3DW3DKuOGs2Z84p5czaUpbUKLgl+JIJ9A3AUjNbTKzlfRPwycQDzKwa6Hb3ceAuYiNeRE5YWyjMq/t62dwS4o14iB+atp6TZSydU8p7llazfG4Zy+eVsay2jMrivBRXLZIaUwa6u4+a2R3AU8SGLf7E3beY2T1Ao7uvAy4HvmdmTqzL5UszWLME1NDwKJubQ7y6v5dN+3p4dX8v7X2xlndOlnFmbSlrVtSysq6cs+vKObO2VFPVRRKYT5yidpI0NDR4Y2NjSn62nBo6+iI07u3h5T3dNO7tZmtb/+Hx3AurijhvfgWr5lewakElZym8RQAws43u3jDZe5qOJifN/u4h1u/q4qV4gO/tGgJid5M5b34lX7z8NM5bUMG59RVUleSnuFqR9KNAlxnT2htm/a4u1u/uYv2uLlp6wwBUFefRsKiST120kIZFs1gxr0zLsIpMAwW6TJuh4VFe3N3F73Yc5D92dLLn4CAAFUW5XLS4irXvXcIlp1Vx+uwSje8WmQEKdDlu7s729n7+Y3snv3uzkw17ehgeG6cgN4uLl1Rxy4ULuOS0as6qLdVUeJGTQIEux2RkbJwNe7p5uqmdZ7e209wT60Y5c04pt166iPcuraFhUaUuYIqkgAJdpjQYHeW5bR08u7Wd57d10BcZJT8ni/ecXs2XrjidK86cTW15QarLFMl4CnSZVHh4jOe2dfDE5lae29ZBdHScquI8PrCilquXz+E9S6u1ZrfIKUa/kXJYZGSMF7Z38sTmVn67tYPwyBjVJfnc9F/mc93Zc2lYNEt3bBc5hSnQM5y7s7k5xGMbm1n3Wiuh8AhVxXl89Pw6PnTOPFYvVoiLpAsFeobq6I/w+KYWHtvYzI72AfJzslizspaPnV/PJadVkaNx4SJpR4GeQdyd/9zZxUPr3+K5bR2MjTsXLKzkex89mw+eM5eyAt20QSSdKdAzQH9khF9ubOafXtzL7s5BqorzuP2yJXy8oZ7TakpSXZ6ITBMFeoDt7OjnoT/u5VevNDM4PMaq+RXc+4lzue7suVobXCSAFOgB9Nr+Xu57fidPN7WTl5PF9efO49MXL+Sc+opUlyYiM0iBHhDuzvrdXfzo+V38YedBygtz+epVS/n0xQu1cqFIhlCgpzl357dbO7jvhZ1s2tdLTWk+37zuLD554UJK8vXPK5JJ9Bufxja81c33ntzKK/t6qa8s5G8/vJIbL6jXOioiGUqBnoZ2dvTz/X/bzrNb25lTls/3P3o2N15Qr7HjIhlOgZ5G2vsi3PvMDh5t3E9xXg5/8YEz+dyliynMU4tcRBToaWF4dJwf/343P3zuTcbGnVsvWcwdV57OLN3dXkQSKNBPcRve6uabv3qdNzsGuHZlLXddu4wFVUWpLktETkEK9FNUaGiE7//7Nh5+eR91FYU8+JkGrlo2J9VlicgpTIF+inF31r3WyneeaKJnaIS1713CV69aSrGGIIrIFJQSp5CewWH+8pebeaapnXPry3noc6tZMa881WWJSJpQoJ8iXtzdxdceeZWuwSjf+uAyPnvpYq1DLiLHRIGeYqNj4/zwuZ38
8Lk3WVhVzK8/cykr69QqF5Fjp0BPodbeMF975FVefqubj55fxz03rNR0fRE5bkqPFHm2qZ2v/8trjI6Nc+8nzuUj59WnuiQRSXMK9JPM3XnwD3v47pNbWTGvjB/efD6Lq4tTXZaIBEBSi3+Y2Roz225mO83szkneX2Bmz5vZJjPbbGbXTX+p6W9s3Pmbf23ib//fVj6wvJbH/uwShbmITJspW+hmlg3cB1wNNAMbzGyduzclHPYt4FF3/19mthx4Elg0A/WmraHhUb7y8Cae3drB7Zct5q5rl5GlUSwiMo2S6XJZDex0990AZvYIcAOQGOgOlMVflwOt01lkuuvoj3DbTxvZ0hrinhtW8OmLF6W6JBEJoGQCvQ7Yn7DdDFw44ZhvA0+b2ZeBYuD9k32Qma0F1gIsWLDgWGtNSzva+/nsP26ge3CYH39a0/dFZOYk04c+Wb+AT9i+Gfipu9cD1wE/M7N3fba73+/uDe7eUFNTc+zVppmm1j4+/r/XMzw2zqNfuFhhLiIzKpkWejMwP2G7nnd3qdwGrAFw9/VmVgBUAx3TUWQ62tU5wKcefImivGwe/cLFzJ+lFRJFZGYl00LfACw1s8VmlgfcBKybcMw+4CoAM1sGFACd01loOtnfPcQtP34JM/j55y9UmIvISTFloLv7KHAH8BSwldholi1mdo+ZXR8/7OvA7Wb2GvAwcKu7T+yWyQgHQhFueeAlwiNj/Oy2C1lSU5LqkkQkQyQ1scjdnyQ2FDFx390Jr5uAS6e3tPTTNRDllgdepGsgys9vv4hlc8um/kMiItNEM0WnSSg8wqcefJmW3jAPfXY1q+ZXpLokEckwuk38NAgPj/HZf3yZNzv6+T+fauDCJVWpLklEMpBa6CfI3fnW42+waX8vP/rk+bzvjOAPxxSRU5Na6Cfo0cb9/PKVZr585VKuPXtuqssRkQymQD8Bb7SE+KvfbOGypdV89aqlqS5HRDKcAv04hcIjfPHnrzCrKI8ffGKVbhcnIimnPvTj4O58419eo7U3zC++cBFVJfmpLklERC304/Hj3+/mmaZ27rpuGRcsnJXqckREAAX6MXtpdxd/9+/buXZlLZ+7dFGqyxEROUyBfgw6+6N8+eFNLJhVxN/feA5m6jcXkVOHAv0YfOeJJnrDI/zolvMpLchNdTkiIu+gQE/S+l1drHutlT9732lao0VETkkK9CSMjI3z1+veoL6ykC9eflqqyxERmZQCPQkP/fEtdrQPcPeHllOQm53qckREJqVAn0JHX4QfPPsml59Zw9XLdQs5ETl1KdCn8L1/28bw6Djf/pMVGtUiIqc0BfpRvLynm19vamHte5ewqLo41eWIiByVAv0IRsfGufs3b1BXUciXrjg91eWIiExJgX4EP3txL9sO9PNXH1pGYZ4uhIrIqU+BPonO/ij/8PQO3ntGDR9YUZvqckREkqJAn8S9z+4gMjrGt/9kuS6EikjaUKBPcHAgymMbm7nxgvksqSlJdTkiIklToE/wf1/cy/DoOLe9Z3GqSxEROSYK9ASRkTF+tn4vV501m9Nnq3UuIulFgZ7g8U0tdA0Oc9tlap2LSPpRoMe5Ow/8YQ8r5pVx8ZKqVJcjInLMFOhxL+zoZGfHAJ+/bLFGtohIWlKgxz34+z3MKcvng2fPS3UpIiLHRYEObG3r4w87D3LrJYvJy9FfiYikp6TSy8zWmNl2M9tpZndO8v69ZvZq/LHDzHqnv9SZ88Dv91CYm80nVy9IdSkiIsctZ6oDzCwbuA+4GmgGNpjZOndvOnSMu/95wvFfBs6bgVpnREdfhHWvtfDJ1QsoL9J9QkUkfSXTQl8N7HT33e4+DDwC3HCU428GHp6O4k6Gh9a/xei48zlNJBKRNJdMoNcB+xO2m+P73sXMFgKLgeeO8P5aM2s0s8bOzs5jrXXaDQ2P8vOX9nHN8jksrNJ65yKS3pIJ9MnG8PkRjr0JeMzdxyZ7093vd/cGd2+oqalJtsYZ
88tXWugdGuH2y5akuhQRkROWTKA3A/MTtuuB1iMcexNp1N3ys/VvcW59ORcsrEx1KSIiJyyZQN8ALDWzxWaWRyy01008yMzOBCqB9dNb4szY1TnAjvYBPnJenSYSiUggTBno7j4K3AE8BWwFHnX3LWZ2j5ldn3DozcAj7n6k7phTytNb2gG4RjewEJGAmHLYIoC7Pwk8OWHf3RO2vz19Zc28p5sOcHZdOfMqClNdiojItMjIaZHtfRE27evlAyvmpLoUEZFpk5GB/kyTultEJHgyMtCf2nKAxdXFLNVNLEQkQDIu0EPhEdbv6uKa5XM0ukVEAiXjAv2F7R2MjjvXqP9cRAIm4wL96S3tVJfkc958TSYSkWDJqECPjIzxwvYOrl4+h6wsdbeISLBkVKD/cddBBofHNFxRRAIpowL96S3tlOTncPFpugm0iARPxgT62LjzTFM7V5w1m/yc7FSXIyIy7TIm0F/Z10PX4DDXLFd3i4gEU8YE+lNvHCAvO4vLz0z9OuwiIjMhIwLd3Xm6qZ1LTq+itED3DRWRYMqIQN92oJ993UNcs1xrt4hIcGVEoD+9pR0zeP/y2akuRURkxmREoD+15QDnL6hkdmlBqksREZkxgQ/0gwNRmtr6uGqZWuciEmyBD/TXW0IAWrtFRAIv+IHeHAv0lXVlKa5ERGRmBT/QW0IsqSnWcEURCbzgB3pziHPqylNdhojIjAt0oHf0RTjQF2GlAl1EMkCgA/3QBdFz6itSXImIyMwLdKBvbg5hBivm6YKoiARfoAP9jZYQp9eUUJyfk+pSRERmXGAD3d3Z3BLi7Hr1n4tIZghsoLf3Rensj3K2LoiKSIYIbKBvbu4F4By10EUkQwQ20N9oCZFlsHyuAl1EMkNSgW5ma8xsu5ntNLM7j3DMfzWzJjPbYmb/PL1lHrvNLSHOmFNKYZ7uHyoimWHK4R9mlg3cB1wNNAMbzGyduzclHLMUuAu41N17zCylSxu6O683h7jyLK2wKCKZI5kW+mpgp7vvdvdh4BHghgnH3A7c5+49AO7eMb1lHpvWUISuwWGNcBGRjJJMoNcB+xO2m+P7Ep0BnGFm/2lmL5rZmsk+yMzWmlmjmTV2dnYeX8VJOLTCoka4iEgmSSbQbZJ9PmE7B1gKXA7cDDxgZu+ab+/u97t7g7s31NTUHGutSXu9pZecLGPZXM0QFZHMkUygNwPzE7brgdZJjvmNu4+4+x5gO7GAT4nNzbELogW5uiAqIpkjmUDfACw1s8VmlgfcBKybcMzjwBUAZlZNrAtm93QWmix35/WWkLpbRCTjTBno7j4K3AE8BWwFHnX3LWZ2j5ldHz/sKaDLzJqA54G/cPeumSr6aJp7wvQOjeiCqIhknKRWrXL3J4EnJ+y7O+G1A/8t/kipt5fMVaCLSGYJ3EzRzc0hcrONM2tLU12KiMhJFbhAf72ll7Nqy8jP0QVREcksgQr0QzNEdcs5EclEgQr0fd1D9EVG1X8uIhkpUIG+WTNERSSDBSrQX28JkZeTxRlzdEFURDJPoAJ9c3Mvy2pLycsJ1GmJiCQlMMk3Pu5saenThCIRyViBCfS3ugbpj45yTt271gQTEckIgQn0XZ2DACydU5LiSkREUiMwgd4WCgNQV1GY4kpERFIjMIHe2hshN9uoLslPdSkiIikRmEBvC4WZU1ZAVtZk9+MQEQm+4AR6b4R55epuEZHMFZhAbw2FmVtRkOoyRERSJhCBPj7utPdFmKsWuohksEAE+sGBKCNjzjy10EUkgwUi0FtDEQC10EUkowUi0Nt6Y2PQ55arhS4imSsQgX6ohT5Pk4pEJIMFItDbesPk52RRWZSb6lJERFImGIEeijCvohAzTSoSkcwViEBvDYXVfy4iGS8Qgd7WqzHoIiJpH+ijY+N09Ec0Bl1EMl7aB3p7f5Rx1xh0EZG0D/QD8XXQtY6LiGS6tA/01t74GHS10EUkw6V9oLephS4iAiQZ6Ga2xsy2m9lO
M7tzkvdvNbNOM3s1/vj89Jc6udbeCCX5OZQVaFKRiGS2nKkOMLNs4D7gaqAZ2GBm69y9acKhv3D3O2agxqNq0xh0EREguRb6amCnu+9292HgEeCGmS0reW2hCHO1houISFKBXgfsT9huju+b6GNmttnMHjOz+ZN9kJmtNbNGM2vs7Ow8jnLfrbU3wjy10EVEkgr0yRZI8Qnb/woscvdzgGeBhyb7IHe/390b3L2hpqbm2CqdRHR0jIMDUY1BFxEhuUBvBhJb3PVAa+IB7t7l7tH45o+BC6anvKNrD8V+pEa4iIgkF+gbgKVmttjM8oCbgHWJB5jZ3ITN64Gt01fikbXGhyxqDLqISBKjXNx91MzuAJ4CsoGfuPsWM7sHaHT3dcBXzOx6YBToBm6dwZoP0xh0EZG3TRnoAO7+JPDkhH13J7y+C7hrekubmmaJioi8La1niraFwlQU5VKYl53qUkREUi69A13roIuIHJbWgd4a0hh0EZFD0jrQ20JhXRAVEYlL20APD4/ROzSiLhcRkbi0DfTDY9DVQhcRAdI40NviQxbVQhcRiUnbQNcsURGRd0rbQD/UQp9Tnp/iSkRETg3pG+ihMNUl+eTnaFKRiAikcaC3hiK6ICoikiBtA72tV7eeExFJlL6BHtK0fxGRRGkZ6H2REQaio+pyERFJkJaBrjHoIiLvlpaBrlmiIiLvlpaBrha6iMi7pWegh8JkGcwu1aQiEZFD0jLQW3sjzCkrICc7LcsXEZkRaZmIbSGNQRcRmShNAz3C3Ar1n4uIJEq7QHd3WnvDuvWciMgEaRfoPUMjREfHNcJFRGSCtAv01l6NQRcRmUzaBXpbSGPQRUQmk4aBHmuhz1ULXUTkHdIu0GvLCrh6+RyqizWpSEQkUU6qCzhW16yo5ZoVtakuQ0TklJN2LXQREZlcUoFuZmvMbLuZ7TSzO49y3I1m5mbWMH0liohIMqYMdDPLBu4DrgWWAzeb2fJJjisFvgK8NN1FiojI1JJpoa8Gdrr7bncfBh4BbpjkuO8Afw9EprE+ERFJUjKBXgfsT9huju87zMzOA+a7+xNH+yAzW2tmjWbW2NnZeczFiojIkSUT6DbJPj/8plkWcC/w9ak+yN3vd/cGd2+oqalJvkoREZlSMoHeDMxP2K4HWhO2S4GVwAtm9hZwEbBOF0ZFRE6uZAJ9A7DUzBabWR5wE7Du0JvuHnL3andf5O6LgBeB6929cUYqFhGRSU05scjdR83sDuApIBv4ibtvMbN7gEZ3X3f0T5jcxo0bD5rZ3uP5s0A1cPA4/2w6y9Tzhsw9d513ZknmvBce6Q1z9yO9d8oys0Z3z7gunUw9b8jcc9d5Z5YTPW/NFBURCQgFuohIQKRroN+f6gJSJFPPGzL33HXemeWEzjst+9BFROTd0rWFLiIiEyjQRUQCIu0CPdmlfNOdmf3EzDrM7I2EfbPM7BkzezP+XJnKGmeCmc03s+fNbKuZbTGzr8b3B/rczazAzF42s9fi5/038f2Lzeyl+Hn/Ij65L3DMLNvMNpnZE/HtwJ+3mb1lZq+b2atm1hjfd0Lf87QK9GSX8g2InwJrJuy7E/ituy8FfhvfDppR4OvuvozYMhJfiv8bB/3co8CV7n4usApYY2YXAX8H3Bs/7x7gthTWOJO+CmxN2M6U877C3VcljD0/oe95WgU6yS/lm/bc/XdA94TdNwAPxV8/BHz4pBZ1Erh7m7u/En/dT+yXvI6An7vHDMQ3c+MPB64EHovvD9x5A5hZPfBB4IH4tpEB530EJ/Q9T7dAn3Ip34Cb4+5tEAs+YHaK65lRZrYIOI/YTVMCf+7xbodXgQ7gGWAX0Ovuo/FDgvp9/wHwl8B4fLuKzDhvB542s41mtja+74S+5+l2k+ijLuUrwWFmJcAvga+5e1+s0RZs7j4GrDKzCuDXwLLJDju5Vc0sM/sQ0OHuG83s8kO7Jzk0UOcdd6m7t5rZbOAZM9t2oh+Ybi30qZby
Dbp2M5sLEH/uSHE9M8LMcomF+c/d/Vfx3Rlx7gDu3gu8QOwaQoWZHWp4BfH7filwfXzp7UeIdbX8gOCfN+7eGn/uIPYf+GpO8HueboF+1KV8M8A64DPx158BfpPCWmZEvP/0QWCru/9DwluBPnczq4m3zDGzQuD9xK4fPA/cGD8scOft7ne5e3186e2bgOfc/RYCft5mVhy/DzNmVgxcA7zBCX7P026mqJldR+x/8ENL+X43xSXNCDN7GLic2HKa7cBfA48DjwILgH3Ax9194oXTtGZm7wF+D7zO232q3yTWjx7Yczezc4hdBMsm1tB61N3vMbMlxFqus4BNwJ+6ezR1lc6ceJfLN9z9Q0E/7/j5/Tq+mQP8s7t/18yqOIHvedoFuoiITC7dulxEROQIFOgiIgGhQBcRCQgFuohIQCjQRUQCQoEuIhIQCnQRkYD4/4uPZ4TkF57aAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot state column 1 (presumably the momentum p) against sample index.\n",
    "plt.plot(cstate[:, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 174,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameters held by the re-initialized optimizer state (i.e. best_params).\n",
    "params = get_params(opt_state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 179,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use serif fonts for subsequent matplotlib figures.\n",
    "plt.rc('font', family='serif')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 180,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dcb46b1394444385abf938114be76535",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAARkAAAEYCAYAAABoTIKyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO2deXhU1fnHP29CICxhDfu+byKyiFoVsVS0rdSqtYtSiwuodan1h2utVbGKu9bWhYqiuHTDFrWiWDdcWERANhEChCUsCQGyEsjy/v64Z8IYk8kNyWQmM+/neeaZe889957vvTPzzjnvWV5RVQzDMMJFQqQFGIYR25iRMQwjrJiRMQwjrJiRMQwjrJiRMQwjrJiRMQwjrMSFkRGRM0RkpYioiHwkIu1d+g0iki4iOSLytzCU+5aIjHPbr4lIUWA/DGVdKCLPhuG6V4nIehE5ICJdg9LHiMhiEdktIs+ISFMR+dDd44MVrvGky7dYRE46Sh3ni8giV8YnIvJ3ERld2/urpkxfz1REfuSe0YdVHK/4/VsoIhtFZI6INPdx/anuezq75ndRfo1VItLPR766/x6paly8gHGAAo0qpN8JfBKmMlsCErSfDoxz2728x19nZSUCLcJ0H5OBEuCNCum9gNkV0tJc3jEV0mcDvY6y/H5AJtDV7ScAs4BpYf7O+H6m7hl96Pf7B7QB9gB3+rz+nRWfdYi8syteF2h9NPdc2bVq+oqLmkykUNVcdZ9UPZRVqqr5YSzicWCCiFxUTb4dwD+B50SkcR2VPQJIV9UMAFUtAx4F1tfR9SslnM9UVfcDHwNhrY0FlXfAZ746v2czMhUQkaEiMt9VaxeJyNSgY4Emz80i8rqIbBaRC0TkYhF5X0RWi0h/l/dG10S4s5IyWgF/c9sfuleiiCSJyIMi8pl7PSgiSZWU/R8RSRORX7tjw1x1PD2ojKnuGu+LyHsiMsSljwnkdRo/clXpAdU8mtXA3cDjItKhmrzXAKnA76vJ55etwAgROSeQoKprVPVN+MazuVtE3nafw6MikhjI7+51sYh8LCJPBBtAEbnIHXvfvcbX5JnWgkZ4Rjlw/b4issB9Jh+LyHeqOlFEfu++n++LyJsi0sWl/wY4C5jsvleXicjD4jV3J7umXaaIbBGRn7lzPnDf1WuD77mKa20Ukf0ico/L87CIZIvIjVXeZTirm9H04kh19SPgw6BXOkHNJeAE4AS3nQR8BfQPOp4O/Mlt/xivGn+u2/8T8HRVVU2qaS7h/Sj/h1dlTQTeAX5f4fwn3fYYIJ8j1e9xeP/2gbxXAE2Cjn1c4VkcBk5x+08Cz4R4dpPdKxH4HPhn0D3MrpD3Q/d+jivjuKBn0asWn9+fgDJgHXAH0KXC8XQ8wy1AMvAlMNUdu8h9js3c8X8At7tj3wF2A+3d/k8C91TDZzqZmjWXegDzONIETHQaL3X7xwJ7gRS3f2fwswauxTXFXdlzqvreBT4XYLLbvg5YEHTsHGBKFff8jWsBI4EcoKnb7wDMCvXZxWNNZryqjgu88B5iMBuBy0TkU+BdoDNedT2Yd937GqA9nmEAWAX0qYW2i4EX1auylgIvApdUyPN2UFnN8T7kylgHvCEiHwMzgFEVjuer6idB1+pdnTin6VfARBE5v5q884C/4zWbGlV3bR9lXwcMAv4DXAZsFJHxFbL9XT2KgH8BP3fpk4G/qWqher+MV4FfumOXAG+papbb/w/wVBUyqnumfnhPRL7Ea+q9q64JCJwI9AXmuPtdBWQAZ1dxne3AByKyELi+hlpeAcbKEUf+BXiGt1pUdTmwDc8wgWfAQ3aaxKORqY5H8H64Y50RWon3DxhMnnsvAVDV4P3a+CG6AVlB+1kuLZhcV2aR2/9Wea459iYwU1VPxfuxNa3sOo6iyq5TGaq6DvgD8BegbTXZrwM6AjeHyiQijwU1G48LUfYGVb0N
z5C/7HQEsz9oOxvvDwK8Z3hhoAynpyzoWPkzV9USVV1SiUY/z9QP4/H+tP4MPBjU9OyGV9N5N0hnE6BVJVr64xmFG1V1LJ6RqfgdrRJV3QssACaJSBugVFVzanAPc/D+EAP3816ozLX+h4lBxuA1SUrdflI9lr0dr2YUoD1BbfYaMBCvZytQ66nre3gIr6n4KLClqkyqul88n9ZcPGNdVb7rQxUmIicAg1V1tstfKiL/Ae6rkDXY6KUCu9z2drxaQ3nXuoikBh1rH5TeCBiqql9WuHadPVNVLXO+usl4TbDpTkex+2MLaGnOEWMYzAggV1U/r4WWF/F8bDn4rMUE8RIwXUS+B3ylniO+Sqwm823S8PwyiEhnvLZxOMhzZTQTkVtE5ES8ptsk5wROACYBzx/Ftbfi1apOcPtn1YHecpwBnoxnkKvL+1+85skJ1eUNQVNgqvvR4Z7NucDCCvnOF49kvCZAoBo/G7jApSPeWKVngo79IMjo/Azv3ipSp89UVQvxeuyuEs+5vwTYJiLnOY2N8JpulTnk04A2Qc76ilrygGYi0lxEXq5CwhtAJ2Aqnu+vKr51LVXdiefbfBHXvAvJ0TriGtILOAPvnzTg+A04+W7Acxjm4LXZwWv3LwMWAc/h+SvWA991D7XIXWsgsNhdcx7eD249cAB4ALgRz6GYjudDeC3o3FGurJeBL/B8PE3x/pHuBz5zr4eAJJc3uOze7nrqNJzm0os44pS90pX9Jl6NQ/GqyEOC8j5TUXclz+4qd3w9zpEadOwGjjhJm+I5Fw+498ZB+Vrh/VP3OsrPrz3wmLvXD4ClwEycU9TlSQduAebj+coeAxKDjv8f3g/5fbwfb8egY5Pc5/0hni+nJTCsBs/0R0HP8Akf378hQc8lB+/7dgqeT+Ztl+djjjiBp7pydwO/c2nTXdo89zkW4fnzAE5yepYCFwIPO23rgR8G6XoKeDRov7J7/sa1KjyzVX4+v4B32jAaNK7bdbKqfhhhKXGBiHwfOCa4CVoV1lwyDMM3IhJw+E7Cq4lXixkZo8EjIq/h+RceE5Gj6VY2/HO2iKwANjnfTLVErLkkIp2Ae4Dhqnp8JccTgHvxBpz1xBvws9gd+x5wHt5AOFXVu+pNuGEYNSKSXdin4DmtqhoX8VOgpareIiJtgcUiMhhv7MDTeN2Mh0RkroiMV9WQffWGYUSGiBkZVf2XhF724Id4nntUdZ+IFAFD8XoatqrqIZfvU5f3W0bGjdGYCtC8efNRgwYNqrsbMAwDgC+++GKvqrav6ng0D8brwJGRteCNUO2AZ2QqS/8WqjoTr6uT0aNH67Jly8Kj1DDiGBHZGup4NBuZTCAlaL+lS9Mq0g3DiEKiqnfJjSoMVLv+izcQCOeTSQbW4g2a6ikiTVy+k11ewzCikIjVZETkNLyZsJ1F5Ha8UYmT8UYdXok3n2KEiPwBb1r8xeoNZy8UkauAP4lIFt6oQ3P6GkaUEjcjfivzyRQXF7Njxw6KioqqOMuoSHJyMt26dSMpqT7njRrRjIh8oapVrvAXzT6ZsLNjxw5SUlLo1asXIhJpOVGPqpKdnc2OHTvo3bva5WcMA4gyn0x9U1RURLt27czA+EREaNeundX8jBoR10YGMANTQ+x5GTUl7o2MYRjhxYxMBDn11FOZNm0al1xyCa1atWLatGlMmzaNyZMn1+g6s2fP5sABL+LFJ598wsiRI/nwww/rXrBhHAVx7fiNNJdeeimXXHIJa9as4YMPPuChhx4C4Pnna7YY3uzZsxk3bhytW7fmlFNO4dhjw7WYn2HUHDMyjrveWMu6nbnVZ6wBQ7q05A8Th1Z5/JJLKgYi8Fi/fj09evRgypQpLFmyhG7dupU7W2fPns3MmTO59957SU9PZ8GCBaSnp/PYY48xaNAgrrzySgAWLFjA3LlzWb58OS+//DK9evWq03szDL9YcykKuf/++8nMzOSaa65h3rx5XHnlld9oQk2dWh5vjgkT
JtCrVy+uv/76cgMD0Lt3b5544gnOPfdc5s6dW5/yDeMbWE3GEarGEQk6duxImzZtADjuuONq7GPp18+LrZ6amkp6enodqzNimcLDJWTmHiIz7xDDuraiaePE6k8KgRmZKKViV3FKSgq5uV5zbtu2bd84lpiYiKqyevVqhgwZUun5hqGq7C8sZueBg+zKKWJ3jnvPLWK3e8/KPUTeoZLyc9689hSO6fqt0E81woxMhDl48CAzZ84kJyeH5557jksvvZRnn32WnJwcHnnkEW644QbAq82UlZUxffp0evbsSU5ODq+++iq/+MUvOOuss5gxYwZFRUVcffXVrFq1ijlz5jBw4EDeeOMN9u/fT1paWnntxohNAkZk+75CtrlXxoGD7Nh/kIz93nZR8TdDJDVKEDq2TKZjyyYM6pTC2P7t6dCyCR1SvLSe7XzHjKuSuJ679NVXXzF48OAIKWq42HOLLDmFxWzam8+mzHy2ZheyJbuA9L0FbM0uJD+oFgLQtnljurZu6r3aNKVL66Z0aZVMl9ZN6dwqmdQWTUhIqF2t1+YuGUYDZV/BYTbsyQt65bM5K5+9+YfL8yQmCN3aNKVXu+aM7tmGHu2a06NtM7q3bUr3Ns1o3iTyP/HIKzCMOEdV2b7vIKszcli3K4d1O3NZtyuXPbmHyvOkJDdiQMcUxg/qSJ/2zenbvgV92jene9tmJCVGdyexGRnDqGcOFB5m+bb9LN96gC93HGB1Rg4HCosBz0fSr0MLTu6byuDOLRnQKYWBHVPo2LJJg3Xmm5ExjDCzK+cgizZls2TzPpZt3cemrALAa+oM6JjCWUM7MaxbK47t2poBnVrQpFHtuoyjDTMyhlHH5BUV82laNgs3ZrFoUzZb9npGpWVyI0b3ast5I7sxqmcbju3WimaNY/8nGPt3aBj1wJa9BSxYu5sPvs5kWfp+SsqUFk0acULvtlx0Qg9O7NOOIZ1b1ronpyFiRibCLF26lJtuuonDhw/zyCOPcOKJJ0ZaEuDN5r7uuut45JFHGDduXKTlRCUb9+Tx1urdzF+zi/W7vSg9gzqlcPmpfRg3sD2jeraJeqdsfRDJhcRDhpoVkVlA36CkY4GRqpouIulAukvPUNWLwq84PIwZM4Zx48aRn58fNQYGsNncVZCdf4h5K3cyd/kO1u7MRQRG92zD788ewlnHdKJr66aRlhh1RMTIiEgzqg81u0BV/+7ytwRmq2q6OzZbVe+sU1Hzb4Hdq+v0knQaBt+fcVSnZmRkcOutt3LMMceQlpbGFVdcwahRo/jPf/7DvHnzGDhwIKtXr+app56iZcuW/OxnP2Pz5s2cccYZfPzxx5x33nm88847JCYmcuyxx7J48WIuvPBCpkyZAsAdd9xBSUkJiYmJpKSkcNNNNwFw3XXXUVxcTJ8+fdixY0edPYqGjKrySdpeXly0lQ/WZ1JSpgzr2oo/TBzCD4d1pkPL5EhLjGqqNTIi0lhVD7vtVCBJVXfVstyTqCbUbMDAOC4DngvaHysiN+EFeZuvqp/VUk/UMW3aNCZOnMiFF15Ieno65557LitWrKBNmzY89thjtGrVikceeYQ5c+Zw9dVXc//993PyySdz1113UVRUxK5duxgxYgS33nor9913H1lZWXz3u99lypQpvPPOOyxevJgFCxYAMG7cOCZMmEBGRgYbN25k/vz5AMybNy+SjyDiFB4u4bXlGcz+LJ20zHxSWzTm0lN6c/7IbgzslFL9BQzAX03mFuBut90YeACYVMtyqwpB+y1EJAE4E3gsWJOqLnU1ouUicraqplVybnks7B49eoRWdJQ1jnCxatUqOnTowLZt21BVOnToQFlZGS1atODuu+8mNTWV5cuXM3Tokdnj/fr1IykpiaSkJFJSUti5cycDBgwAoH379uTl5ZVfu7CwkBkzvHvu3r07WVlZrF27lv79+5dfr0+fPvV4x9FD/qESnv9kC89+soWcg8Uc07UlD18wnLOHd4657uX6
oEojIyLHAscBx4nIxS45AWheB+VWFYK2Ms4B3tSgSVaqutS9F4rISrwokt8yMhVjYdeB7nph9erV7Nu3j/Hjx/OjH/0IVaVr164kJCRw+eWX8/jjjzN27FhmzpzJzp07y8+rbLBWZWnDhw9n0aJF3HLLLQC8//779OvXj0OHDvH++++X59u8eXMY7i56KTxcwgufbWXmwk3sLyzme4M7cMVpfRnds02DHQgXDYSqybQBege9A5QCj9ZBueWhZl2T6WTgSReOtkRVg5eomwyUO3ZFZDxek+1tl9QP2FQHmiLCsmXLWLhwIYcPH+aee+4BYMuWLVx77bW8/vrrrF69mt27d3P66acDcNlllzF9+nROP/10vvjii/IZ1rNnz2br1q3lM7kPHTrEnDlzWLVqFcuWLWPNmjXk5OQwd+5czj//fJYuXcqtt95Ko0aNKCoqYsaMGXTv3p358+dz+eWX0717d1SVOXPmMGrUKFJSYrd5oKrMXZ7BjPlfsTf/MKcNaM8NZwxgePfWkZYWE1Q7C1tE+qvqxqD9RBcutnYFi5wB/ATIAopV9S4ReQDYp6ozXJ7jgItU9cag84YBdwJfAF3wepfuq648m4Vdd8TSc9uwJ4/b/72Gpen7GNmjNb/74WBG9WwbaVkNirqYhZ0mIiM40rz5JTCltsJU9V3g3QppN1XYXwmsrJC2Gji/tuUb8U1RcSmP/W8jz368mRbJjZhx3jB+Orp7XA6WCzd+jMwbgAJ73f6w8MkxjPCTvreAq19Zztqdufx0dDdu+f5g2jZvHGlZMYsfI5OpqpcGdlytJmZQVXPq1YCGvsjZf1ft4ua5q0hMEGb9ajTjB3eMtKSYx8+Y50UiEtyXOTxcYuqb5ORksrOzG/wPp75QVbKzs0lObniDz4pLy7hj3hqufmU5/Tu24K3fnGoGpp4I1YW9DzgACHCbiKjbbgnMrhd1YaZbt27s2LGDrKysSEtpMCQnJ9OtW7dIy6gRRcWlXP3yct5bn8mUU3tz01mDbE5RPRKquXSNqr5SMVFEfhJGPfVKUlISvXv3rj6j0WApOFTClBeXsWhzNvf8+Bgmndgz0pLijirNecDAiMjZFdL/FW5RhlEX5BQWM2nWEpZs2ccjPx1uBiZC+HH8PiEiNwTtK97gt7tUNSM8sgyjduQUFvPzvy5mU2Y+T140kjOHdoq0pLjFj5F5FW/i4magDzAG+AS4A7gifNIM4+goLVOu/dsK0jLzmPWr4xk7oH2kJcU1frxfB1X1PVXd4pZiaKyqHwMbqzvRMCLBA2+vZ+GGLO4+5xgzMFGAn5rM8SIyGm8C4gBgjIg0BgaGVZlhHAXzVmbwzMLNTDqxB78YU83Me6Ne8GNk/gA8AwwF1gJXAUOAj8KoyzBqzJqMHG761yrG9GrLHWcPrf4Eo16o1sio6go8PwwAItLZLVq1suqzDKN+ySksZuqLy2jXvDFPThpJ40Y2DiZaCDUYb7iqfhm0lkyAicAF4ZVlGDXj4Xe/ZnduEf/+9cmktmgSaTlGEKHM/XXu/RK89WQCL5sHb0QV63bm8tLirUw6saetAROFVFmTUdXL3OZ1bnkFAETEGrtG1KCq3Pn6Wlo3a8wNZwyItByjEqqsyYhIDxHpAeQEtt1+gw0/YsQer3+5k6Xp+7jxzIG0bmbLNUQjoRy/H+LFNqq4DkIP4LYw6TEM3+QfKuHet75iWNdW/HR090jLMaqgugmSb1VMFJEfhFGPYfjmifc3sif3EE9NGkWirWgXtYSaIPktAxMq3TDqk+37Cnnuky38ZFQ3RvZoE2k5RghsMIHRIHll6TZKy9ScvQ2AUONkbgYeqovIBFVcv7pY2JOBK4EilzRLVee4Y5OAEXghWjap6jPh0GhEJ4dLyvjH59v57qCOdLHY01FPKJ9MkqqWisiVqvp0IFFETlPVWk0p8BkLG+DnQfGvA+d2A6YBI1RVReRzEXk/OGyLEdu8s3Y32QWHmXSizU1qCIQy
MgNdbWKCiBQGpU+k9vOWqo2F7bhGRHYDzYA/q+o+vJC1XwRFlFwEfB+bFR43vLxkK93aNGVsf5th3RAI5ZOZgRc8LRBBsi5H/PqJhf0RcL+qPgQsA/5Zg3MBLxa2iCwTkWW2jm9skJaZz+LN+7jwhB4WI6mBEGrE71pgrYi8EYYRv9XGwlbVLUG77wOvi0iiy9evwrnfioPtrtEgY2EbVfPKkm0kJYqNi2lA+Old2iwi94jIGyIyHW+AXm0pj4Xt9k8G/isibUWkJYCI3CciASPYH9jinNDvAKPkSLCkk4D5daDJiHKKikv51xfbOeuYzjYJsgHhZz2ZR/DW9H0e78f+KDC1NoWqaqGIXAX8SUSygFWq+l4gFjZeU2038JSIbMGLWvlLd+4OEXkIeFRESoFnzekbH7zx5U5yi0q46ARz+DYk/BiZTar6QGBHRH5XFwVXFwtbVR8Pce5LwEt1ocNoOLy8ZBv9OrTghN62EEBDwk9zqbvzheCaL13DK8kwvk1aZh4rtx/gwjE9LKxwA8NPTeZdIF1EsvF6lq4OryTD+DYLN+wFYMJQCy3b0PCz/ObrIrIQr0cnTVUPhF+WYXyTzzbtpWe7ZnRr0yzSUowa4mvukqoeUNVlZmCMSFBSWsaSzfv4Tt/USEsxjgKbIGlEPasycsg7VMLJ/dpFWopxFJiRMaKez9I8f8xJfczINERqbGRE5MpwCDGMqvg0LZvBnVvSzgbgNUhCrfG7T0Q2i8h2ESkQkXQ3UfIP9ajPiHOKikv5Ytt+Tu5rtZiGSqiazDWq2gd4AOigqr3wJiL+pT6EGQbAsvT9HC4p4+R+5vRtqIRafvMVt9lRVQtcWj7ehETDqBc+3bSXRgnCGBvl22DxMxhvsIjciLdeywC8+UuGUS98lraX47q3pnkTP19VIxrx4/idArR376nA5WFVZBiOnIPFrM7I4TvWVGrQ+Bnxu09EbgHaAdmqWhZ+WYYBizdnU6aY07eBU21NRkQmAJuBWcAvROSKsKsyDLymUtOkREZYyJMGjZ/m0kRgEPCpqr4M9A2vJMPw+HRTNsf3bkvjRjZmtCHj59PboapFQGD5Spu/ZISdzLwi0jLzrakUA/hx2Q9wPplBInINtp6MUQ+s25kLwHHdW0dYiVFb/NRkrscbG5MKdAJuDqsiwwA27skHYEDHlGpyGtGOn96lPBG5HW/Bqn3Wu2TUBxv25JHaogltmjeOtBSjltSkd+k5rHfJqCc2ZuYzoGOLSMsw6gA/PplA79JvVPVlF1Gg1viIhX0zXvNsNzAKuENV17tj6RwJzZKhqhfVhSYjOlBV0jLzOX+kuf9iAT9GZoeqFolInfUu+YyF3QK4wcW7/hnwIJ7BA5itqnfWVocRnezMKSL/UAn9zR8TE/hx/AZ6l4bUYe9SVbGwy1HV3wfFu04A8oMOjxWRm0Rkuoh8p6pCLExtw2TDHi8KsTl9Y4NI9S7VJJ51Y+BXwO1Bybe4WFD3Ac+JSL/KzlXVmao6WlVHt29vwdkbCmmuZ6l/B/PJxAJ+e5cexBvpu8kt91Bbqo2FDeUG5ingd6q6KUjTUvdeKCIr8cLcVhoP22h4WM9SbOGnd+mXwFq83qV1InJxHZTrJxZ2U+AZ4BFV/UJEznfp40XkrKBr9cMLo2vECBusZymm8OP4/RHQS1UPi0gy8ArwYm0K9RkL+2XgGKC3ixjYHJiLV+O5U0RGAl2Auar6SW30GNGDqpK2J48LRnePtBSjjvBjZFao6mEA18v0OYCIdFXVjKMt2Ecs7POqOG81cP7RlmtENztziig4XEo/88fEDH6MzDEicjfegLw+QAfXZJoIXBBOcUb8YT1LsYef3qXOQCnQ073vAnrjTTMwjDplY7mRsZpMrOCnJnOda6J8AxEZGgY9RpyzYU8+7VOa0LqZ9SzFClUaGWdEWqnqZyLSHC/eUmPgXlXNVNW19SXSiB82Zubb+JgYI1Rz6SG87mGAPwLHAtuAR8Mt
yohPAj1L5o+JLUI1l75Q1UBX9c+AUaq6U0TuqwddRhySceAgBYdL6W/+mJgiVE2mBEBETsSb6bzTpReFXZURl9hCVbFJqJpMXzchchLwPICIdMMbIGcYdU6g+9p8MrFFqJrMjXjd1v8EnnRp04G3wi3KiE82ZlrPUiwSqiZzEt7ExMOBBFW9JPySjHhl4548Gx8Tg4SqyfQF5ovI8yJypohY8BsjbJSVqeu+Nn9MrFGl4VDVh1R1PF739RjgQxF5UkTG1ps6I27YmXOQQutZikmqrZ2oapqqTlfVsXhLZn5fRH4VfmlGPJG+txCAPqlmZGINP9MKEJGJeAtLrQSmq2phWFUZcUd6dgEAvVKbRViJUdf4WbTqAbyoAmPxphXMCLcoI/7Yml1Ak0YJdExJjrQUo47x48w94HqVNqvqCrxFpQyjTknPLqRnu2YkJEikpRh1jB8jk+reA5EDzP1v1Dlbswvo2a55pGUYYcCPkdkgIuuAi0VkKbAuzJqMOKOsTNmaXUivduaPiUX8RCt4WkQ+AoYCq1X16/DLMuKJPXlFHCops5pMjOKrd0lVvwK+Ai9gmqrOrG3BPsLUJuMtN5EB9AdmqOoGd2wSMAJvpb5NqvpMbfUYkSPQfd3LjExMEmrRqn14IWmFI/4YwevKrpWR8Rmm9npgm6o+ICLDgFnAqW6S5jRghAth+7mIvK+qG2ujyYgcW133dU9rLsUkoXwy16hqH1Xt7d77qGpv4No6KLfaMLVufxGURygY7mIynYm31k3A8C0Cvl8HmowIkZ5dSFKi0KV100hLMcJAqGkFrwCIyGgROd5tnwi8Xgfl+glTW1WemoS4tVjYDYCt2QV0b9OMROu+jkn89C7dCBS47QLgnjoo10+Y2qry+ApxCxYLu6Gw1Y2RMWITP0bmc1VdB+XNll11UG61YWqB/+I1q3A+mS9VNRd4BxglLqykyzO/DjQZEUBVbYxMjOOnd2mAiKSq6l4RScXr6akVPsPUPg48JCK34y1ofpk7d4eIPAQ8KiKlwLPm9G247M0/TMHhUhsjE8P4MTKzgRUikgLkAD+vi4J9hKk9CFxdxbkvAS/VhQ4jspT3LKVaTSZW8TMY7zOge1BtxroAjDojPdvGyMQ6oSKXzBgAABTLSURBVMbJBIzK2KA0gF8CU+pBmxEHbM0uIDFB6Grd1zFLqJrMHLzxJ4/jrSMTYFhYFRlxRXp2IV1bN6VxI1vdNVap0sioamCA23Wq+jGAW+f3pPoQZsQHXs+SOX1jGT9/H4ODtocAl4dJixFnqCpb9haYPybGCeWTaQm0BgaJSA+XnA8cquocw6gJBwqLySsqsZpMjBPKJ3MuMBnoBRyHNzmyBHg77KqMuKB8XV+rycQ0oXwyLwAviMgFwL9VtaT+ZBnxwNZA97UtHh7T+PHJPAEMCrcQI/5Izy5ABLq1MSMTy/gxMi+p6prAjoiMDKMeI47Yll1Il1ZNSU5KjLQUI4z4mVbQWkRm4K2Mp8BE4IKwqjLigvTsAnq0tVpMrOOnJjMKOIjnAO4NtA2nICN+2JpdaP6YOMBPTeYqVV0c2BGRWs/CNozcomKyCw7bEg9xgJ9Y2ItFpLmI9HDjZS6qB11GjLMpMx+A3jb7OuaptiYjIjfgTYpMAfYAXYA7wyvLiHXW7coFYEjnltXkNBo6fnwynVR1BPBXVT0Z+HOYNRlxwLqduaQkN6JbG5t9Hev4MTL57j2wru7AMGkx4oh1u3IZ0rklR1ZRNWIVP0amm4hMBLaLyCagc5g1GTFOaZmyflceQ7pYUyke8LMy3tTAtogsAjaEVZER82zZW8DB4lKGdmkVaSlGPVBtTUZEWovI/SLyJjAJsIENRq0wp2984ae5NAvIBp7DiyQwqzYFurAnM0XkFhGZJSIdK8lzvIi8LCLTROSvIjIl6NjTIvJh0MtW6mtgrNuZS1Ki0K9Di0hLMeoBP4PxvlbVBwI7InJ/Lcu8F/ifqv7D+Xoewusi
D6Yz8LiqLhWRJCBTRP6tqnuB3ap6ZS01GBFk3a5c+ndIsSU34wQ/n3JeIEKBiDQDdrrtXxxlmeUxrqk8Bjaq+rqqLg1KKgGK3XaKiPxORG4WkWtExI+hNKKIdTtzzekbR/j5gf4GuN0FYesA7BOR6/HCw75a2Qki8g7wrWYQcAffjGWdC7QRkUYh1qu5BrhXVXPc/st4weBKXDC4W4HpVeiYCkwF6NGjR2VZjHomM6+IvfmHGGpGJm7wY2T+qKpPVEx0ESArRVXPrOqYiARiWR/AM1T7qzIwInIh0FxVy+Nvq+ryoCzvAzdThZFR1ZnATIDRo0drVZqM+mPdTnP6xht+5i59y8C49KeOsszyGNe4GNjgRUIIWksYEbkc6KCq94jIMBEZ4NIfDLpWfyDtKHUYEWCtMzKDrSYTN0TCn3EbcL8zGn2BaS79WLxYT8NE5BzgYbzwuD8G2gHX4o3Rae/WtynEG318Qz3rN2rBul25dG/blJbJSZGWYtQToaIVnAYsVNU6bWao6j4qiUCpqitxgeNUdR5Q6UgtVZ1cl3qM+uWrnbnWVIozQjWXfq6qKiJnBSeKSK+wKjJiloJDJWzJLmBIZxvpG0+EMjLFzqBMCKwl43wmv64XZUbMsX53HqpYz1KcEconswivZ6Y/MCIovQdwUzhFGbFJ+XQCMzJxRai4S68Cr4rIRFV9I5AuIj+oF2VGzLFuZy6tmyXRuVVypKUY9YifWdhviMj3gOHASlV9K/yyjFhk3c4cW0MmDvEzC/v3eN3EPYFpbt8wakRJaRnrd+dZz1Ic4mecTGNVLW8iich9YdRjxCjrd+dxqKSMoV3NyMQbfiZIllbYLwuHECO2+WB9JgCn9GsfYSVGfeOnJlMiIq8Dm/FG6C4JryQjFvnf+kyGd29N+5QmkZZi1DN+HL/3iMgEvGH//1XVd8Mvy4glsvIO8eX2A/zfGQMiLcWIAL7mLqnqAmBBmLUYMUqgqfTdwR0irMSIBLY0mRF23lu/h86tkq1nKU4xI2OElaLiUj7euJfvDupg42PiFD/jZO6tDyFGbLJkyz4KD5fyvcGVLZRoxAN+ajLDROQvIvJrEbHo6EaNeO+rPTRNSuSkvu0iLcWIEH4cvz9V1YMiMgT4i4hkA39W1S1h1mY0cFSV977K5OR+qSQnJUZajhEh/NRkxovIKLwFu08CdgHni8htYVVmNHi+3pNHxoGDfM96leIaPzWZl4AVwBPAr1S1DEBEHg+nMKPh895Xrut6kBmZeMaPkfl9YDFxEWkmImV4Uw1WhlWZ0eB576s9HNutFR1a2tIO8Yyf5lLToO3OwCxVLVbV58OkyYgBMnOLWLH9gNVijJALifcAegGDRGSsS04AarWwuIi0BWbgzYXqD9ymqnsqyZcOpLvdDFW9yKX3An6PFwqlF/B/qppfG01G3fPConQAzjmua0R1GJEnVHNpBPBj4DggMIqqFHizlmX6iYUNMFtV76wk/WngDhcn+1q84G62xk0UUXCohJcWb+PMIZ3onWqjHuKdUMtvzgPmicjxqvp5HZb5Q+CPbvtT4IUq8o0VkZvwok3OV9XPRCQJOB34POj8Z6nCyFiY2sjwj2XbyTlYzNTT+kRaihEFhGouiYu5tCc4siNwpaqG7L6uo1jYt7jaSjNguYicDRQAB4NiQeW661WKhamtf0pKy5j1yRZG92zDyB5tIi3HiAJCNZeWAGOAj4AtHGky9cCLAlkldRELW1WXuvdCEVmJF9L2FaBpkAFsCWSG0mLUL2+t2c2O/Qe54+whkZZiRAmhmktj3OZ1dRytIBALezsVYmED3VR1m4iMB5JU9W13Tj9gk6oWi8gHwPHA0uDzjcijqsxcuIk+7ZvbXCWjHD/jZHYF/DIiciKwsJZlVhsLG692cqeIjAS6AHNV9ROX70rgDreQVg8sFnbUsGhTNmsycrnvvGEkJNiMa8PDj5GZBtzttguAe4Drj7ZAn7GwVwPnV3F+OnDp
0ZZvhI9nFm4mtUUTzh1h3dbGEfwMxlumquug/Me/K7ySjIbIsvR9fLQhi8nf6WmTIY1v4MfIDBCRVAD33i+8koyGRlFxKTfNXUXX1k255OTekZZjRBl+mkuzgRUi0hKvR+jnYVVkNDj+9N5GNmcV8OKlY2jexNey0UYc4SdawWdAdxFJVdW99aDJaECsycjhmYWbuWBUN8YOsJhKxrfxs/xmJxF5GfhAROaIiPVNGgAUl5Zx079W0bZ5Y27/oY2LMSrHj0/mj8A84Fd485ZmhFWR0WCYuXAz63blMv2cY2jVLCnScowoxU8Der2q/sNtLxcRm5BisGLbfh5/byM/GNaJs47pFGk5RhTjpybTzy3PEOhdMiMT52zNLuCyF5bRqWUy0885JtJyjCjHT03mBeBLEUkBcrDepbhmX8FhJj//OarK7EuOp10Li21thKbGvUsi0rS6c4zYpKi4lMtf+JyMAwd5dcoJ9GnfItKSjAZAqKUexlaSBt4CU9+aFmDENiWlZfz27ytZsf0AT144klE920ZaktFACFWTeRT4kiNLPAQYFj45RjRScKiEa15ZzgdfZ/H7s4fw/WGdIy3JaECEMjLXqeqnFRNF5OQw6jGijKy8Q1z2wuesycjh3nOHceEJtsKgUTNCrSfzKYCINAauAJLwFrLaWD/SjEizOSufXz2/lL15h/nrxaMZb2vEGEeBny7sR4G2eGu37OLIsg9GDLNg7W7Oe+ozCg+V8repJ5qBMY4aP0YmXVXvAnap6mYgI8yajAhSeLiEW19bxdQ5X9C1dVNe+/V3GN69daRlGQ0YP+Nk+ohIE0DdEpn2lxajfLn9ANf/fSXp2QVceVpfbjhjAI0b+fkfMoyqCdWFfayqrgIW4C0krnjhRX5bT9qMeuJA4WEefXcDLy3ZRoeUJrxy+Ymc1LddpGUZMUKomsxTInKbqv7bLd7dD0hT1QP1pM0IMyWlZby6dBsPv7uB3IPFXHRCT6ZNGGiTHY06JZSRmQN0EZFngK+BF+rCwPgJUysi44C/AFkuqQPwD1W9U0SeBgYFZb/WLQtq+KSsTHl77W4e/99Gvt6Tx0l92vGHHw1hUKeWkZZmxCChurCfdpuvikh/4LdudbzXVPXDWpTpJ0ztTmCSqq4AEJFZwPPu2G5VvbIW5cctJaVlvLlqF3/+II20zHz6pDbnqYtGctYxnQKjuQ2jzvG7VmI6sBa4BpiE16V9tFQbplZVNwS23SJZTVR1q0tKEZHfASV40ROeriw4nHGEnMJi/vnFduYs3srW7EIGdkzhiV+M4AfDOpNooUuMMBPK8TsBz+F7BXAxkAb8FfhbdRetozC1AX4NPB20/zKwSlVLROQB4FZgehU64jYWtqqyakcOLy3eyutf7uRQSRmjerbh1u8PZsKQjhYXyag35EhY6QoHRLLwxtG8DPy1rvweIrId+I6qbnf+mTRVrbRm5LrOX1PVH1Zx/CzgZlU9vbpyR48ercuWLauN9AbB9n2FzFuZwb9XZLApq4BmjRP58YiuTDqhJ0O6mM/FqHtE5AtVHV3V8VDNpfnAVFUtqmNN1YapDcp7IfBq8Mki8qCq3uh2++PVsOKardkFvLtuD2+v2c2yrfsBGNOrLZee0puJw7vQMtl6i4zIEcrIXBUGAwP+wtQGuAA4p8L57UVkBlAIDCQOw9SWlJaxcvsBPtqQxbvr9rB+t9f6HNQphRvPHMg5x3WhW5tmEVZpGB5VNpdijYbcXFJVNmXls2jzPj7ZmMVnadnkHSohQeD4Xm2ZMLQTE4Z0pHtbMyxG/VOb5pIRIYpLy1i/K4/l2/azdMs+lmzJZm/+YQC6tm7K2cM7c2r/9pzcN9UGzhlRjxmZCKOqbM0uZHVGDqszcli57QCrMg5QVFwGQJdWyZzavz0n9G7LCX3a0atdMxvTYjQozMjUI0XFpWzck89Xu3P5apf3Wrszl7wir/e+cWICQ7q05MIxPRnRozUjerSma+umZlSM
Bo0ZmTCQV1TM5qwCNu/NJy0znw17vPet2QWUORdYclICAzumMHF4F4Z1bcWwrq0Y0DHFZj0bMYcZmaMkp7CY7fsL2ZpdSHp2Advc++a9BWTlHSrP1yhB6J3anMGdPYMyqFMKgzu3pEfbZjba1ogLzMhUQlmZsq/wMLsOFJFx4KD32n+QjAOFbN93kO37C8ubOAFSWzShZ7tmjBvQnt7tm9MntQV92zenZ7vmVjsx4hozMo7CoiK2PXw6xaVlHC4pI9Cxn+peIwQaN0qgSaMEmjRPpHErt90okSaNEo7USnLda1NEbqNmxMnwBaMW/PhJSO1fq0uYkXE0TWpEcUITkpISaOGMSeNGiSQneYakUaJ8KzZMbBCbd2XUEVL7WrgZGYckNmLYrR9GWoZhxBzmLDAMI6yYkTEMI6yYkTEMI6yYkTEMI6yYkTEMI6yYkTEMI6yYkTEMI6yYkTEMI6yYkTEMI6yYkTEMI6yYkTEMI6zUu5ERkQQRuUJEMkXkmBD5viciT4rInSLyh6D0tiIyU0RuEZFZLsKkYRhRSiQmSA4HluCFNKkUEWmGFzVyqKoeEpG5IjJeVd/DXyxtwzCihHqvyajqClVdWU22k4CtqhpYYu5TvBjauPdFlaQbhhGFhKUmEyoWtqq+7uMSwfGywVsGqkMlx0LG0g6OhQ3ki8jXPspOBfb6yBcpTF/tiGZ90awNqtbXM9RJYTEyqnpmLS+RCaQE7bd0acHHDrj0/ZUZGKdjJjCzJgWLyLJQgaoijemrHdGsL5q1wdHri6reJRHp7TYXAT1FpInbL4+ZzZFY2hXTDcOIQurd8SsibYCrgVbAVBF5RVUXi0h74BMR6auqhSJyFfAnEckCVjmnL1QdS9swjCik3o2Mqu4H7nGv4PQsoGvQ/rvAu5Wcvw+YEkaJNWpeRQDTVzuiWV80a4Oj1CdqK9YbhhFGosonYxhG7GFGxjCMsGIhURwi8j3gPLwuclXVuyIsCRHphOe7Gq6qx7u0ZLxRzhlAf2CGqm6IgLa+TttyoBuQrap3i0hbYAaw2em7TVX3REBfAvAG3ujyxnidBJcCTaNBn9PY1OlboKrTouWzDdK3GChyu6WqOv6oPl9VjfsX0AxIA5q4/bnA+CjQ9RNgIrAsKO0W4Ca3PQz4OELajgfOCdpfB4zCmw7yU5c2EZgTIX0JwO1B+/OAi6JFnyv/YeAF4KFo+myD9N1ZSVqNn581lzxCTWOIGKr6L7458hmCplWo6mpguIi0jIC2z1V1XlBSAlBAlEz7UNUyVb0HQEQa4dW2vo4WfSLyS1f+lqDkqPhsgxgmIje7ScpHPa3HmkseoaYxRBtVac2NjBwQkXOBd1R1vYj4nvZRT9rOBH4LvKmqy6JBn4gMAQar6m0icmzQoWj7bO9X1aUikggsFJE8ajCtJ4DVZDxCTWOINqJKq4icDpyO90OGb+oLOe2jPlDVd1T1LKC3iPw6SvSdCxSJyC3AKcAYEbmeKPtsVXWpey8FPsb7nGv8/Kwm41E+jcE1mU4GnoywpqoITKv4WESGAV+qakT+6VwV+lTgN0BnEekZpG87EZz24WoLvVU1UP4WoE806FPVPwbpTAZaqOpjbjtaPttBwMmqOssl9Qde4yienw3Gc4jIGXiO1iygWKOjd+k04GLgLOApPEcheD0Qu4B+wL0amd6lUcBHwDKX1Bz4C/A6cD+wFa9H5xaNTO9SX+BBvN6vJGAwcB1wOBr0OY3n402xaYz37P5DFHy2TlsXp2k5Xo0lCbgBaE0Nn58ZGcMwwor5ZAzDCCtmZAzDCCtmZAzDCCtmZAzDCCtmZAzDCCs2TsYIOyLyMd5EwHZ4k1D/6g51xevh/HmktBnhx7qwjbAjIpeo6vMumN+bqtorkA7MVvsSxjTWXDLCjqo+X8WhFNwEQRG5RER2i8iNIjJHROaLyE9dlNCFgYmCIjJURF50+WaJSJ/6ug/j6DAjY0QMVf1T0PbzwHpguar+EjgE
pKjqZcAK4AyX9VngaVV9EJjDkVHQRpRiPhkj2tjk3g8Ebe/nyKS8Y4EJIjIWbwGq/PqVZ9QUMzJGQ+NL4DVVXeXicp0baUFGaMzIGPWCW2pyKtBKRC5V1efc0gutROQXeOFPewKTReR1vBrLL0VkJzAWbwGl+cBlwP+JyBagO/BSJO7H8I/1LhmGEVbM8WsYRlgxI2MYRlgxI2MYRlgxI2MYRlgxI2MYRlgxI2MYRlgxI2MYRlj5f66cL8BnOWD8AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 288x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Compare true vs. learned trajectories, one generated test system per panel.\n",
    "N_ROWS, N_COLS = 1, 1  # panel grid; each panel shows one generated test system\n",
    "N_STEPS = 50           # time samples per trajectory (data and integration grid)\n",
    "\n",
    "fig, axes = plt.subplots(N_ROWS, N_COLS, figsize=(4*N_COLS, 4*N_ROWS),\n",
    "                         sharex=True, sharey=True, squeeze=False)\n",
    "ax_idx = [(i, j) for i in range(N_ROWS) for j in range(N_COLS)]\n",
    "\n",
    "# The learned vector field does not depend on the panel, so wrap and jit it once\n",
    "# instead of rebuilding (and re-tracing) it on every loop iteration.\n",
    "runner = jax.jit(jax.vmap(\n",
    "    partial(\n",
    "        hamiltonian_eom,\n",
    "        learned_dynamics(params, nn_forward_fn)), (0, 0), 0))\n",
    "\n",
    "@jax.jit\n",
    "def odefunc_learned(y, t):\n",
    "    # ODE right-hand side for odeint: y packs [q, p, conditional]; the\n",
    "    # conditional is carried along unchanged (its time derivative is 0).\n",
    "    return jnp.concatenate((runner(y[None, :2], y[None, [2]])[0], jnp.zeros(1)))\n",
    "\n",
    "t_grid = jnp.linspace(0, 1, N_STEPS)\n",
    "\n",
    "for i in tqdm(range(len(ax_idx))):\n",
    "    ci = ax_idx[i]\n",
    "    cax = axes[ci]\n",
    "\n",
    "    cstate, cconditionals, ctarget = gen_data((i+4)*(i+1), 1, N_STEPS)\n",
    "\n",
    "    yt_learned = odeint(\n",
    "        odefunc_learned,\n",
    "        jnp.concatenate([cstate[0], cconditionals[0]]),\n",
    "        t_grid,\n",
    "        mxsteps=100)\n",
    "\n",
    "    # Column 1 of the state is p (hamiltonian_eom splits the state into q, p).\n",
    "    # NOTE(review): the y-label calls this velocity / c -- confirm gen_data's\n",
    "    # units make p equal to v/c.\n",
    "    cax.plot(cstate[:, 1], label='Truth')\n",
    "    cax.plot(yt_learned[:, 1], label='Learned')\n",
    "    cax.legend()\n",
    "    if ci[1] == 0:           # left-most column gets the y label\n",
    "        cax.set_ylabel('Velocity of particle/Speed of light')\n",
    "    if ci[0] == N_ROWS - 1:  # bottom row gets the (shared) x label\n",
    "        cax.set_xlabel('Time')\n",
    "    cax.set_ylim(-1, 1)\n",
    "\n",
    "# Same target as the original plt.title(...): the last-created axes.\n",
    "cax.set_title('Hamiltonian NN - Special Relativity')\n",
    "fig.tight_layout()\n",
    "fig.savefig('sr_hnn.png', dpi=150)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "main2",
   "language": "python",
   "name": "main2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
